// trustformers_debug/attention_visualizer.rs
1//! Attention pattern visualization for transformer models
2//!
3//! This module provides tools to visualize and analyze attention patterns in transformer
4//! models, including multi-head attention, cross-attention, and self-attention mechanisms.
5
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
11// Note: ndarray types available for advanced array operations if needed
12
/// Attention pattern visualizer for transformer models
#[derive(Debug)]
pub struct AttentionVisualizer {
    /// Stored attention weights, keyed by layer name
    attention_weights: HashMap<String, AttentionWeights>,
    /// Token vocabulary for labeling
    // NOTE(review): only written by `set_token_vocab`; nothing in this module
    // reads it yet — confirm the intended consumer before relying on it.
    token_vocab: Option<Vec<String>>,
    /// Visualization configuration (thresholds, display limits, colors)
    config: AttentionVisualizerConfig,
}
23
/// Configuration for attention visualization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionVisualizerConfig {
    /// Whether to normalize attention weights
    // NOTE(review): not consulted by any method in this module yet.
    pub normalize: bool,
    /// Minimum attention weight to display (for filtering).
    /// Also used as the threshold for sparsity and flow extraction in `analyze`.
    pub min_weight: f64,
    /// Maximum number of tokens to visualize
    // NOTE(review): not consulted by any method in this module yet
    // (`plot_heatmap_ascii` uses its own hard-coded limit of 20).
    pub max_tokens: usize,
    /// Color scheme for visualization
    pub color_scheme: ColorScheme,
}
36
/// Color scheme options for visualization
// NOTE(review): the current rendering helper (`weight_to_color`) emits a fixed
// white-to-red gradient and does not consult this setting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ColorScheme {
    /// Blue to red gradient
    BlueRed,
    /// Grayscale
    Grayscale,
    /// Viridis color map
    Viridis,
    /// Plasma color map
    Plasma,
}
49
50impl Default for AttentionVisualizerConfig {
51    fn default() -> Self {
52        Self {
53            normalize: true,
54            min_weight: 0.01,
55            max_tokens: 512,
56            color_scheme: ColorScheme::BlueRed,
57        }
58    }
59}
60
/// Attention weights for a specific layer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionWeights {
    /// Layer name
    pub layer_name: String,
    /// Number of attention heads (length of the outer `weights` vector)
    pub num_heads: usize,
    /// Attention weights, indexed as `weights[head][query_pos][key_pos]`
    /// (shape [num_heads, seq_len, seq_len])
    pub weights: Vec<Vec<Vec<f64>>>,
    /// Source tokens (query tokens) — label the matrix rows
    pub source_tokens: Vec<String>,
    /// Target tokens (key tokens) — label the matrix columns
    pub target_tokens: Vec<String>,
    /// Attention type
    pub attention_type: AttentionType,
}
77
/// Type of attention mechanism
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AttentionType {
    /// Self-attention (source and target tokens are the same sequence)
    SelfAttention,
    /// Cross-attention (source and target are different sequences)
    CrossAttention,
    /// Encoder-decoder attention
    EncoderDecoderAttention,
}
88
/// Attention pattern analysis results, produced by [`AttentionVisualizer::analyze`]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionAnalysis {
    /// Layer name
    pub layer_name: String,
    /// Average attention entropy per head (mean Shannon entropy over rows)
    pub entropy_per_head: Vec<f64>,
    /// Average attention sparsity per head (fraction of weights below the
    /// configured `min_weight` threshold)
    pub sparsity_per_head: Vec<f64>,
    /// Most attended token positions as `(position, total_weight)`,
    /// summed across all heads and rows
    pub most_attended_tokens: Vec<(usize, f64)>,
    /// Strongest attention flow patterns (above the `min_weight` threshold)
    pub flow_patterns: Vec<AttentionFlow>,
}
103
/// A single attention flow between two token positions
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionFlow {
    /// Source (query) position
    pub from: usize,
    /// Target (key) position
    pub to: usize,
    /// Attention weight
    pub weight: f64,
    /// Head index this flow was observed in
    pub head: usize,
}
116
/// Heatmap data for attention visualization (a single head's matrix plus labels)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionHeatmap {
    /// Layer name
    pub layer_name: String,
    /// Head index
    pub head: usize,
    /// Attention weights matrix, indexed `[row][col]`
    pub weights: Vec<Vec<f64>>,
    /// Row labels (query tokens)
    pub row_labels: Vec<String>,
    /// Column labels (key tokens)
    pub col_labels: Vec<String>,
}
131
132impl AttentionVisualizer {
133    /// Create a new attention visualizer
134    ///
135    /// # Example
136    ///
137    /// ```
138    /// use trustformers_debug::AttentionVisualizer;
139    ///
140    /// let visualizer = AttentionVisualizer::new();
141    /// ```
142    pub fn new() -> Self {
143        Self {
144            attention_weights: HashMap::new(),
145            token_vocab: None,
146            config: AttentionVisualizerConfig::default(),
147        }
148    }
149
150    /// Create a new attention visualizer with custom configuration
151    pub fn with_config(config: AttentionVisualizerConfig) -> Self {
152        Self {
153            attention_weights: HashMap::new(),
154            token_vocab: None,
155            config,
156        }
157    }
158
159    /// Set token vocabulary for labeling
160    pub fn set_token_vocab(&mut self, tokens: Vec<String>) {
161        self.token_vocab = Some(tokens);
162    }
163
164    /// Register attention weights for a layer
165    ///
166    /// # Arguments
167    ///
168    /// * `layer_name` - Name of the layer
169    /// * `weights` - Attention weights [num_heads, seq_len, seq_len]
170    /// * `source_tokens` - Source (query) tokens
171    /// * `target_tokens` - Target (key) tokens
172    /// * `attention_type` - Type of attention mechanism
173    ///
174    /// # Example
175    ///
176    /// ```
177    /// # use trustformers_debug::{AttentionVisualizer, AttentionType};
178    /// # let mut visualizer = AttentionVisualizer::new();
179    /// let weights = vec![
180    ///     vec![vec![0.5, 0.3, 0.2], vec![0.1, 0.6, 0.3], vec![0.2, 0.3, 0.5]]
181    /// ];
182    /// let tokens = vec!["Hello".to_string(), "world".to_string(), "!".to_string()];
183    ///
184    /// visualizer.register(
185    ///     "layer.0.attention",
186    ///     weights,
187    ///     tokens.clone(),
188    ///     tokens.clone(),
189    ///     AttentionType::SelfAttention,
190    /// ).unwrap();
191    /// ```
192    pub fn register(
193        &mut self,
194        layer_name: &str,
195        weights: Vec<Vec<Vec<f64>>>,
196        source_tokens: Vec<String>,
197        target_tokens: Vec<String>,
198        attention_type: AttentionType,
199    ) -> Result<()> {
200        let num_heads = weights.len();
201
202        let attention_weights = AttentionWeights {
203            layer_name: layer_name.to_string(),
204            num_heads,
205            weights,
206            source_tokens,
207            target_tokens,
208            attention_type,
209        };
210
211        self.attention_weights.insert(layer_name.to_string(), attention_weights);
212
213        Ok(())
214    }
215
216    /// Get attention weights for a specific layer
217    pub fn get_attention(&self, layer_name: &str) -> Option<&AttentionWeights> {
218        self.attention_weights.get(layer_name)
219    }
220
221    /// Create a heatmap for a specific head in a layer
222    pub fn create_heatmap(&self, layer_name: &str, head: usize) -> Result<AttentionHeatmap> {
223        let attention = self
224            .attention_weights
225            .get(layer_name)
226            .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
227
228        if head >= attention.num_heads {
229            anyhow::bail!(
230                "Head {} out of range (max: {})",
231                head,
232                attention.num_heads - 1
233            );
234        }
235
236        let weights = &attention.weights[head];
237
238        Ok(AttentionHeatmap {
239            layer_name: layer_name.to_string(),
240            head,
241            weights: weights.clone(),
242            row_labels: attention.source_tokens.clone(),
243            col_labels: attention.target_tokens.clone(),
244        })
245    }
246
247    /// Analyze attention patterns for a layer
248    pub fn analyze(&self, layer_name: &str) -> Result<AttentionAnalysis> {
249        let attention = self
250            .attention_weights
251            .get(layer_name)
252            .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
253
254        let entropy_per_head = attention
255            .weights
256            .iter()
257            .map(|head_weights| compute_entropy(head_weights))
258            .collect();
259
260        let sparsity_per_head = attention
261            .weights
262            .iter()
263            .map(|head_weights| compute_sparsity(head_weights, self.config.min_weight))
264            .collect();
265
266        let most_attended_tokens = find_most_attended_tokens(&attention.weights);
267
268        let flow_patterns = extract_attention_flows(&attention.weights, self.config.min_weight);
269
270        Ok(AttentionAnalysis {
271            layer_name: layer_name.to_string(),
272            entropy_per_head,
273            sparsity_per_head,
274            most_attended_tokens,
275            flow_patterns,
276        })
277    }
278
279    /// Plot attention heatmap as ASCII art
280    pub fn plot_heatmap_ascii(&self, layer_name: &str, head: usize) -> Result<String> {
281        let heatmap = self.create_heatmap(layer_name, head)?;
282
283        let mut output = String::new();
284        output.push_str(&format!(
285            "Attention Heatmap: {} (Head {})\n",
286            layer_name, head
287        ));
288        output.push_str(&"=".repeat(60));
289        output.push('\n');
290
291        // Limit display size for readability
292        let max_display = 20;
293        let display_rows = heatmap.row_labels.len().min(max_display);
294        let display_cols = heatmap.col_labels.len().min(max_display);
295
296        // Column headers
297        output.push_str("        ");
298        for col in 0..display_cols {
299            output.push_str(&format!(
300                "{:4}",
301                heatmap.col_labels[col].chars().next().unwrap_or('?')
302            ));
303        }
304        output.push('\n');
305
306        // Rows with values
307        for row in 0..display_rows {
308            let label = &heatmap.row_labels[row];
309            output.push_str(&format!(
310                "{:6}  ",
311                label.chars().take(6).collect::<String>()
312            ));
313
314            for col in 0..display_cols {
315                let weight = heatmap.weights[row][col];
316                let symbol = weight_to_symbol(weight);
317                output.push_str(&format!("{:4}", symbol));
318            }
319            output.push('\n');
320        }
321
322        if display_rows < heatmap.row_labels.len() || display_cols < heatmap.col_labels.len() {
323            output.push_str(&format!(
324                "\n(Showing {}/{} rows, {}/{} cols)\n",
325                display_rows,
326                heatmap.row_labels.len(),
327                display_cols,
328                heatmap.col_labels.len()
329            ));
330        }
331
332        Ok(output)
333    }
334
335    /// Export attention weights to JSON
336    pub fn export_to_json(&self, layer_name: &str, output_path: &Path) -> Result<()> {
337        let attention = self
338            .attention_weights
339            .get(layer_name)
340            .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
341
342        let json = serde_json::to_string_pretty(attention)?;
343        std::fs::write(output_path, json)?;
344
345        Ok(())
346    }
347
348    /// Export to BertViz-compatible format (HTML)
349    pub fn export_to_bertviz(&self, layer_name: &str, output_path: &Path) -> Result<()> {
350        let attention = self
351            .attention_weights
352            .get(layer_name)
353            .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
354
355        let mut html =
356            String::from("<html><head><title>Attention Visualization</title></head><body>");
357        html.push_str(&format!("<h1>{}</h1>", layer_name));
358
359        for head in 0..attention.num_heads {
360            html.push_str(&format!("<h2>Head {}</h2>", head));
361            html.push_str("<table border='1'><tr><th></th>");
362
363            // Column headers
364            for token in &attention.target_tokens {
365                html.push_str(&format!("<th>{}</th>", html_escape(token)));
366            }
367            html.push_str("</tr>");
368
369            // Rows
370            for (row_idx, source_token) in attention.source_tokens.iter().enumerate() {
371                html.push_str(&format!("<tr><th>{}</th>", html_escape(source_token)));
372
373                for col_idx in 0..attention.target_tokens.len() {
374                    let weight = attention.weights[head][row_idx][col_idx];
375                    let color = weight_to_color(weight);
376                    html.push_str(&format!(
377                        "<td style='background-color: {}'>{:.3}</td>",
378                        color, weight
379                    ));
380                }
381                html.push_str("</tr>");
382            }
383
384            html.push_str("</table>");
385        }
386
387        html.push_str("</body></html>");
388        std::fs::write(output_path, html)?;
389
390        Ok(())
391    }
392
393    /// Get summary statistics for all layers
394    pub fn summary(&self) -> String {
395        let mut output = String::new();
396        output.push_str("Attention Summary\n");
397        output.push_str(&"=".repeat(80));
398        output.push('\n');
399
400        for (layer_name, attention) in &self.attention_weights {
401            output.push_str(&format!("\nLayer: {}\n", layer_name));
402            output.push_str(&format!("  Num Heads: {}\n", attention.num_heads));
403            output.push_str(&format!(
404                "  Seq Length: {}\n",
405                attention.source_tokens.len()
406            ));
407            output.push_str(&format!(
408                "  Attention Type: {:?}\n",
409                attention.attention_type
410            ));
411
412            if let Ok(analysis) = self.analyze(layer_name) {
413                output.push_str(&format!(
414                    "  Avg Entropy: {:.4}\n",
415                    analysis.entropy_per_head.iter().sum::<f64>()
416                        / analysis.entropy_per_head.len() as f64
417                ));
418                output.push_str(&format!(
419                    "  Avg Sparsity: {:.4}\n",
420                    analysis.sparsity_per_head.iter().sum::<f64>()
421                        / analysis.sparsity_per_head.len() as f64
422                ));
423            }
424        }
425
426        output
427    }
428
429    /// Clear all stored attention weights
430    pub fn clear(&mut self) {
431        self.attention_weights.clear();
432    }
433
434    /// Get number of stored layers
435    pub fn num_layers(&self) -> usize {
436        self.attention_weights.len()
437    }
438}
439
440impl Default for AttentionVisualizer {
441    fn default() -> Self {
442        Self::new()
443    }
444}
445
446// Helper functions
447
/// Compute the mean Shannon entropy (in bits) over the rows of an
/// attention matrix.
///
/// Each row is normalized by its own sum before computing entropy; rows
/// whose sum is not positive are skipped. Returns 0.0 when no row
/// contributes.
fn compute_entropy(weights: &[Vec<f64>]) -> f64 {
    let row_entropies: Vec<f64> = weights
        .iter()
        .filter_map(|row| {
            let sum: f64 = row.iter().sum();
            if sum <= 0.0 {
                // Nothing to normalize against; ignore this row entirely.
                return None;
            }
            let entropy: f64 = row
                .iter()
                .filter(|&&w| w > 0.0)
                .map(|&w| {
                    let p = w / sum;
                    -p * p.log2()
                })
                .sum();
            Some(entropy)
        })
        .collect();

    match row_entropies.len() {
        0 => 0.0,
        n => row_entropies.iter().sum::<f64>() / n as f64,
    }
}
475
/// Compute sparsity: the fraction of attention weights strictly below
/// `threshold`. Returns 0.0 for an empty matrix.
fn compute_sparsity(weights: &[Vec<f64>], threshold: f64) -> f64 {
    let mut total = 0usize;
    let mut below = 0usize;

    for row in weights {
        for &w in row {
            total += 1;
            if w < threshold {
                below += 1;
            }
        }
    }

    if total == 0 {
        0.0
    } else {
        below as f64 / total as f64
    }
}
488
/// Find the most attended token positions across all heads.
///
/// Sums attention weights per key position over every head and row, then
/// returns the top 10 positions as `(position, total_weight)` sorted by
/// descending weight.
///
/// The accumulator grows to the longest row seen, so ragged inputs
/// (rows of differing lengths) are handled safely instead of panicking
/// on an out-of-bounds index; positions missing from shorter rows simply
/// receive no contribution from them.
fn find_most_attended_tokens(weights: &[Vec<Vec<f64>>]) -> Vec<(usize, f64)> {
    let mut token_attention: Vec<f64> = Vec::new();

    for head_weights in weights {
        for row in head_weights {
            // Grow lazily: the original sized this from weights[0][0] and
            // panicked if a later row was longer.
            if row.len() > token_attention.len() {
                token_attention.resize(row.len(), 0.0);
            }
            for (i, &weight) in row.iter().enumerate() {
                token_attention[i] += weight;
            }
        }
    }

    let mut indexed: Vec<(usize, f64)> = token_attention.into_iter().enumerate().collect();
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    indexed.truncate(10);
    indexed
}
512
513/// Extract significant attention flows
514fn extract_attention_flows(weights: &[Vec<Vec<f64>>], threshold: f64) -> Vec<AttentionFlow> {
515    let mut flows = Vec::new();
516
517    for (head, head_weights) in weights.iter().enumerate() {
518        for (from, row) in head_weights.iter().enumerate() {
519            for (to, &weight) in row.iter().enumerate() {
520                if weight >= threshold {
521                    flows.push(AttentionFlow {
522                        from,
523                        to,
524                        weight,
525                        head,
526                    });
527                }
528            }
529        }
530    }
531
532    flows.sort_by(|a, b| b.weight.partial_cmp(&a.weight).unwrap_or(std::cmp::Ordering::Equal));
533    flows.into_iter().take(100).collect()
534}
535
/// Map an attention weight to a block-drawing character: denser glyphs for
/// stronger weights, a space for weights at or below 0.2.
fn weight_to_symbol(weight: f64) -> &'static str {
    match weight {
        w if w > 0.8 => "█",
        w if w > 0.6 => "▓",
        w if w > 0.4 => "▒",
        w if w > 0.2 => "░",
        _ => " ",
    }
}
550
/// Map an attention weight in [0, 1] to a CSS color: white at 0 fading to
/// pure red at 1. (The float-to-u8 cast saturates for out-of-range weights.)
fn weight_to_color(weight: f64) -> String {
    let fade = 255 - (weight * 255.0) as u8;
    format!("rgb(255, {0}, {0})", fade)
}
556
/// Escape the HTML special characters `&`, `<`, `>`, and `"` in a string.
/// (Single quotes are intentionally left as-is, matching the callers that
/// only interpolate into element content.)
fn html_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            other => escaped.push(other),
        }
    }
    escaped
}
564
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh visualizer starts with no registered layers.
    #[test]
    fn test_attention_visualizer_creation() {
        let visualizer = AttentionVisualizer::new();
        assert_eq!(visualizer.num_layers(), 0);
    }

    /// Registering one layer's weights makes it visible via `num_layers`.
    #[test]
    fn test_register_attention() {
        let mut visualizer = AttentionVisualizer::new();

        // One head, 3x3 self-attention matrix matching the 3 tokens below.
        let weights = vec![vec![
            vec![0.5, 0.3, 0.2],
            vec![0.1, 0.6, 0.3],
            vec![0.2, 0.3, 0.5],
        ]];

        let tokens = vec!["A".to_string(), "B".to_string(), "C".to_string()];

        visualizer
            .register(
                "layer.0",
                weights,
                tokens.clone(),
                tokens,
                AttentionType::SelfAttention,
            )
            .expect("operation failed in test");

        assert_eq!(visualizer.num_layers(), 1);
    }

    /// A heatmap for head 0 carries the layer name, head index, and matrix.
    #[test]
    fn test_create_heatmap() {
        let mut visualizer = AttentionVisualizer::new();

        let weights = vec![vec![
            vec![0.5, 0.3, 0.2],
            vec![0.1, 0.6, 0.3],
            vec![0.2, 0.3, 0.5],
        ]];

        let tokens = vec!["A".to_string(), "B".to_string(), "C".to_string()];

        visualizer
            .register(
                "layer.0",
                weights,
                tokens.clone(),
                tokens,
                AttentionType::SelfAttention,
            )
            .expect("operation failed in test");

        let heatmap = visualizer.create_heatmap("layer.0", 0).expect("operation failed in test");
        assert_eq!(heatmap.layer_name, "layer.0");
        assert_eq!(heatmap.head, 0);
        assert_eq!(heatmap.weights.len(), 3);
    }

    /// Analysis produces one entropy/sparsity value per head and a
    /// non-empty most-attended list for non-trivial weights.
    #[test]
    fn test_analyze_attention() {
        let mut visualizer = AttentionVisualizer::new();

        let weights = vec![vec![
            vec![0.7, 0.2, 0.1],
            vec![0.1, 0.8, 0.1],
            vec![0.1, 0.1, 0.8],
        ]];

        let tokens = vec!["A".to_string(), "B".to_string(), "C".to_string()];

        visualizer
            .register(
                "layer.0",
                weights,
                tokens.clone(),
                tokens,
                AttentionType::SelfAttention,
            )
            .expect("operation failed in test");

        let analysis = visualizer.analyze("layer.0").expect("operation failed in test");
        assert_eq!(analysis.entropy_per_head.len(), 1);
        assert_eq!(analysis.sparsity_per_head.len(), 1);
        assert!(!analysis.most_attended_tokens.is_empty());
    }

    /// JSON export writes a file at the requested path.
    #[test]
    fn test_export_to_json() {
        use std::env;

        let temp_dir = env::temp_dir();
        let output_path = temp_dir.join("attention.json");

        let mut visualizer = AttentionVisualizer::new();
        // Minimal 1-head, 1-token attention.
        let weights = vec![vec![vec![1.0]]];
        let tokens = vec!["A".to_string()];

        visualizer
            .register(
                "layer.0",
                weights,
                tokens.clone(),
                tokens,
                AttentionType::SelfAttention,
            )
            .expect("operation failed in test");

        visualizer
            .export_to_json("layer.0", &output_path)
            .expect("operation failed in test");
        assert!(output_path.exists());

        // Clean up
        let _ = std::fs::remove_file(output_path);
    }

    /// Entropy of a mixed + one-hot pair of rows is strictly positive.
    #[test]
    fn test_compute_entropy() {
        let weights = vec![vec![0.5, 0.3, 0.2], vec![1.0, 0.0, 0.0]];

        let entropy = compute_entropy(&weights);
        assert!(entropy > 0.0);
    }

    /// Sparsity is a fraction in (0, 1] for weights straddling the threshold.
    #[test]
    fn test_compute_sparsity() {
        let weights = vec![vec![0.9, 0.05, 0.05], vec![0.01, 0.01, 0.98]];

        let sparsity = compute_sparsity(&weights, 0.1);
        assert!(sparsity > 0.0);
        assert!(sparsity <= 1.0);
    }
}