1use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10
/// Collects per-layer attention weights and renders them as heatmaps,
/// ASCII plots, HTML tables, JSON exports, and text summaries.
#[derive(Debug)]
pub struct AttentionVisualizer {
    /// Registered attention weights, keyed by layer name.
    attention_weights: HashMap<String, AttentionWeights>,
    /// Optional token vocabulary set via `set_token_vocab`.
    /// NOTE(review): written but never read anywhere in this file — confirm intended use.
    token_vocab: Option<Vec<String>>,
    /// Analysis/rendering options (only `min_weight` is consulted in this file).
    config: AttentionVisualizerConfig,
}
23
/// Options controlling attention analysis and rendering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionVisualizerConfig {
    /// Whether to normalize weights before display.
    /// NOTE(review): not consulted anywhere in this file — confirm intended use.
    pub normalize: bool,
    /// Threshold used both as the sparsity cutoff and as the minimum weight
    /// for an edge to appear in extracted attention flows.
    pub min_weight: f64,
    /// Maximum number of tokens to consider.
    /// NOTE(review): not consulted anywhere in this file — confirm intended use.
    pub max_tokens: usize,
    /// Palette used when rendering.
    /// NOTE(review): not consulted anywhere in this file — confirm intended use.
    pub color_scheme: ColorScheme,
}
36
/// Named color palettes for heatmap rendering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ColorScheme {
    /// Blue-to-red diverging palette.
    BlueRed,
    /// Single-channel grayscale.
    Grayscale,
    /// Perceptually-uniform viridis palette.
    Viridis,
    /// Perceptually-uniform plasma palette.
    Plasma,
}
49
50impl Default for AttentionVisualizerConfig {
51 fn default() -> Self {
52 Self {
53 normalize: true,
54 min_weight: 0.01,
55 max_tokens: 512,
56 color_scheme: ColorScheme::BlueRed,
57 }
58 }
59}
60
/// Attention weights captured for a single layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionWeights {
    /// Name of the layer these weights belong to.
    pub layer_name: String,
    /// Number of attention heads (outer length of `weights`).
    pub num_heads: usize,
    /// Weights indexed as `[head][source][target]`.
    pub weights: Vec<Vec<Vec<f64>>>,
    /// Labels for the source (row) positions.
    pub source_tokens: Vec<String>,
    /// Labels for the target (column) positions.
    pub target_tokens: Vec<String>,
    /// Kind of attention these weights came from.
    pub attention_type: AttentionType,
}
77
/// Kind of attention mechanism a set of weights came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AttentionType {
    /// Attention within a single sequence.
    SelfAttention,
    /// Attention from one sequence to another.
    CrossAttention,
    /// Decoder-to-encoder attention in a seq2seq model.
    EncoderDecoderAttention,
}
88
/// Aggregate statistics computed from one layer's attention weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionAnalysis {
    /// Layer the analysis was computed for.
    pub layer_name: String,
    /// Mean row-wise Shannon entropy (bits), one entry per head.
    pub entropy_per_head: Vec<f64>,
    /// Fraction of weights below the configured `min_weight`, one entry per head.
    pub sparsity_per_head: Vec<f64>,
    /// Up to 10 target positions with the highest summed incoming attention,
    /// as `(token_index, total_weight)` sorted descending.
    pub most_attended_tokens: Vec<(usize, f64)>,
    /// Up to 100 strongest attention edges at or above `min_weight`.
    pub flow_patterns: Vec<AttentionFlow>,
}
103
/// A single attention edge from a source position to a target position.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionFlow {
    /// Source (row) position index.
    pub from: usize,
    /// Target (column) position index.
    pub to: usize,
    /// Attention weight of this edge.
    pub weight: f64,
    /// Head index the edge was observed in.
    pub head: usize,
}
116
/// A single head's weight matrix with row/column labels, ready for rendering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionHeatmap {
    /// Layer the heatmap was built from.
    pub layer_name: String,
    /// Head index within the layer.
    pub head: usize,
    /// The head's `[source][target]` weight matrix (cloned from the layer).
    pub weights: Vec<Vec<f64>>,
    /// Labels for the rows (source tokens).
    pub row_labels: Vec<String>,
    /// Labels for the columns (target tokens).
    pub col_labels: Vec<String>,
}
131
132impl AttentionVisualizer {
133 pub fn new() -> Self {
143 Self {
144 attention_weights: HashMap::new(),
145 token_vocab: None,
146 config: AttentionVisualizerConfig::default(),
147 }
148 }
149
150 pub fn with_config(config: AttentionVisualizerConfig) -> Self {
152 Self {
153 attention_weights: HashMap::new(),
154 token_vocab: None,
155 config,
156 }
157 }
158
159 pub fn set_token_vocab(&mut self, tokens: Vec<String>) {
161 self.token_vocab = Some(tokens);
162 }
163
164 pub fn register(
193 &mut self,
194 layer_name: &str,
195 weights: Vec<Vec<Vec<f64>>>,
196 source_tokens: Vec<String>,
197 target_tokens: Vec<String>,
198 attention_type: AttentionType,
199 ) -> Result<()> {
200 let num_heads = weights.len();
201
202 let attention_weights = AttentionWeights {
203 layer_name: layer_name.to_string(),
204 num_heads,
205 weights,
206 source_tokens,
207 target_tokens,
208 attention_type,
209 };
210
211 self.attention_weights.insert(layer_name.to_string(), attention_weights);
212
213 Ok(())
214 }
215
216 pub fn get_attention(&self, layer_name: &str) -> Option<&AttentionWeights> {
218 self.attention_weights.get(layer_name)
219 }
220
221 pub fn create_heatmap(&self, layer_name: &str, head: usize) -> Result<AttentionHeatmap> {
223 let attention = self
224 .attention_weights
225 .get(layer_name)
226 .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
227
228 if head >= attention.num_heads {
229 anyhow::bail!(
230 "Head {} out of range (max: {})",
231 head,
232 attention.num_heads - 1
233 );
234 }
235
236 let weights = &attention.weights[head];
237
238 Ok(AttentionHeatmap {
239 layer_name: layer_name.to_string(),
240 head,
241 weights: weights.clone(),
242 row_labels: attention.source_tokens.clone(),
243 col_labels: attention.target_tokens.clone(),
244 })
245 }
246
247 pub fn analyze(&self, layer_name: &str) -> Result<AttentionAnalysis> {
249 let attention = self
250 .attention_weights
251 .get(layer_name)
252 .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
253
254 let entropy_per_head = attention
255 .weights
256 .iter()
257 .map(|head_weights| compute_entropy(head_weights))
258 .collect();
259
260 let sparsity_per_head = attention
261 .weights
262 .iter()
263 .map(|head_weights| compute_sparsity(head_weights, self.config.min_weight))
264 .collect();
265
266 let most_attended_tokens = find_most_attended_tokens(&attention.weights);
267
268 let flow_patterns = extract_attention_flows(&attention.weights, self.config.min_weight);
269
270 Ok(AttentionAnalysis {
271 layer_name: layer_name.to_string(),
272 entropy_per_head,
273 sparsity_per_head,
274 most_attended_tokens,
275 flow_patterns,
276 })
277 }
278
279 pub fn plot_heatmap_ascii(&self, layer_name: &str, head: usize) -> Result<String> {
281 let heatmap = self.create_heatmap(layer_name, head)?;
282
283 let mut output = String::new();
284 output.push_str(&format!(
285 "Attention Heatmap: {} (Head {})\n",
286 layer_name, head
287 ));
288 output.push_str(&"=".repeat(60));
289 output.push('\n');
290
291 let max_display = 20;
293 let display_rows = heatmap.row_labels.len().min(max_display);
294 let display_cols = heatmap.col_labels.len().min(max_display);
295
296 output.push_str(" ");
298 for col in 0..display_cols {
299 output.push_str(&format!(
300 "{:4}",
301 heatmap.col_labels[col].chars().next().unwrap_or('?')
302 ));
303 }
304 output.push('\n');
305
306 for row in 0..display_rows {
308 let label = &heatmap.row_labels[row];
309 output.push_str(&format!(
310 "{:6} ",
311 label.chars().take(6).collect::<String>()
312 ));
313
314 for col in 0..display_cols {
315 let weight = heatmap.weights[row][col];
316 let symbol = weight_to_symbol(weight);
317 output.push_str(&format!("{:4}", symbol));
318 }
319 output.push('\n');
320 }
321
322 if display_rows < heatmap.row_labels.len() || display_cols < heatmap.col_labels.len() {
323 output.push_str(&format!(
324 "\n(Showing {}/{} rows, {}/{} cols)\n",
325 display_rows,
326 heatmap.row_labels.len(),
327 display_cols,
328 heatmap.col_labels.len()
329 ));
330 }
331
332 Ok(output)
333 }
334
335 pub fn export_to_json(&self, layer_name: &str, output_path: &Path) -> Result<()> {
337 let attention = self
338 .attention_weights
339 .get(layer_name)
340 .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
341
342 let json = serde_json::to_string_pretty(attention)?;
343 std::fs::write(output_path, json)?;
344
345 Ok(())
346 }
347
348 pub fn export_to_bertviz(&self, layer_name: &str, output_path: &Path) -> Result<()> {
350 let attention = self
351 .attention_weights
352 .get(layer_name)
353 .ok_or_else(|| anyhow::anyhow!("Layer {} not found", layer_name))?;
354
355 let mut html =
356 String::from("<html><head><title>Attention Visualization</title></head><body>");
357 html.push_str(&format!("<h1>{}</h1>", layer_name));
358
359 for head in 0..attention.num_heads {
360 html.push_str(&format!("<h2>Head {}</h2>", head));
361 html.push_str("<table border='1'><tr><th></th>");
362
363 for token in &attention.target_tokens {
365 html.push_str(&format!("<th>{}</th>", html_escape(token)));
366 }
367 html.push_str("</tr>");
368
369 for (row_idx, source_token) in attention.source_tokens.iter().enumerate() {
371 html.push_str(&format!("<tr><th>{}</th>", html_escape(source_token)));
372
373 for col_idx in 0..attention.target_tokens.len() {
374 let weight = attention.weights[head][row_idx][col_idx];
375 let color = weight_to_color(weight);
376 html.push_str(&format!(
377 "<td style='background-color: {}'>{:.3}</td>",
378 color, weight
379 ));
380 }
381 html.push_str("</tr>");
382 }
383
384 html.push_str("</table>");
385 }
386
387 html.push_str("</body></html>");
388 std::fs::write(output_path, html)?;
389
390 Ok(())
391 }
392
393 pub fn summary(&self) -> String {
395 let mut output = String::new();
396 output.push_str("Attention Summary\n");
397 output.push_str(&"=".repeat(80));
398 output.push('\n');
399
400 for (layer_name, attention) in &self.attention_weights {
401 output.push_str(&format!("\nLayer: {}\n", layer_name));
402 output.push_str(&format!(" Num Heads: {}\n", attention.num_heads));
403 output.push_str(&format!(
404 " Seq Length: {}\n",
405 attention.source_tokens.len()
406 ));
407 output.push_str(&format!(
408 " Attention Type: {:?}\n",
409 attention.attention_type
410 ));
411
412 if let Ok(analysis) = self.analyze(layer_name) {
413 output.push_str(&format!(
414 " Avg Entropy: {:.4}\n",
415 analysis.entropy_per_head.iter().sum::<f64>()
416 / analysis.entropy_per_head.len() as f64
417 ));
418 output.push_str(&format!(
419 " Avg Sparsity: {:.4}\n",
420 analysis.sparsity_per_head.iter().sum::<f64>()
421 / analysis.sparsity_per_head.len() as f64
422 ));
423 }
424 }
425
426 output
427 }
428
429 pub fn clear(&mut self) {
431 self.attention_weights.clear();
432 }
433
434 pub fn num_layers(&self) -> usize {
436 self.attention_weights.len()
437 }
438}
439
impl Default for AttentionVisualizer {
    /// Equivalent to [`AttentionVisualizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
445
/// Mean row-wise Shannon entropy (bits) of an attention matrix.
///
/// Rows whose weights sum to zero are skipped; returns 0.0 when no row has
/// positive mass. Each row is normalized to a distribution before the
/// `-p * log2(p)` terms are summed (zero weights contribute nothing).
fn compute_entropy(weights: &[Vec<f64>]) -> f64 {
    let row_entropies: Vec<f64> = weights
        .iter()
        .filter_map(|row| {
            let mass: f64 = row.iter().sum();
            if mass > 0.0 {
                let entropy = row
                    .iter()
                    .filter(|&&w| w > 0.0)
                    .map(|&w| {
                        let p = w / mass;
                        -p * p.log2()
                    })
                    .sum::<f64>();
                Some(entropy)
            } else {
                None
            }
        })
        .collect();

    if row_entropies.is_empty() {
        0.0
    } else {
        row_entropies.iter().sum::<f64>() / row_entropies.len() as f64
    }
}
475
/// Fraction of entries strictly below `threshold`; 0.0 for an empty matrix.
fn compute_sparsity(weights: &[Vec<f64>], threshold: f64) -> f64 {
    let mut total = 0usize;
    let mut below = 0usize;

    for row in weights {
        total += row.len();
        below += row.iter().filter(|&&w| w < threshold).count();
    }

    if total == 0 {
        0.0
    } else {
        below as f64 / total as f64
    }
}
488
/// Sums incoming attention per target position across all heads and rows,
/// returning up to 10 `(position, total_weight)` pairs sorted descending.
///
/// The sequence length is taken from the first row of the first head;
/// returns an empty vector when there is no such row.
fn find_most_attended_tokens(weights: &[Vec<Vec<f64>>]) -> Vec<(usize, f64)> {
    let seq_len = match weights.first().and_then(|head| head.first()) {
        Some(first_row) => first_row.len(),
        None => return Vec::new(),
    };

    let mut totals = vec![0.0; seq_len];
    for head_weights in weights {
        for row in head_weights {
            for (pos, &w) in row.iter().enumerate() {
                totals[pos] += w;
            }
        }
    }

    let mut ranked: Vec<(usize, f64)> = totals.into_iter().enumerate().collect();
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    ranked.truncate(10);
    ranked
}
512
513fn extract_attention_flows(weights: &[Vec<Vec<f64>>], threshold: f64) -> Vec<AttentionFlow> {
515 let mut flows = Vec::new();
516
517 for (head, head_weights) in weights.iter().enumerate() {
518 for (from, row) in head_weights.iter().enumerate() {
519 for (to, &weight) in row.iter().enumerate() {
520 if weight >= threshold {
521 flows.push(AttentionFlow {
522 from,
523 to,
524 weight,
525 head,
526 });
527 }
528 }
529 }
530 }
531
532 flows.sort_by(|a, b| b.weight.partial_cmp(&a.weight).unwrap_or(std::cmp::Ordering::Equal));
533 flows.into_iter().take(100).collect()
534}
535
/// Maps a weight in roughly `[0, 1]` to a block glyph of increasing density
/// (bands open at 0.2/0.4/0.6/0.8; values at or below 0.2 render as a space).
fn weight_to_symbol(weight: f64) -> &'static str {
    match weight {
        w if w > 0.8 => "█",
        w if w > 0.6 => "▓",
        w if w > 0.4 => "▒",
        w if w > 0.2 => "░",
        _ => " ",
    }
}
550
/// Maps a weight to a CSS color from white (0.0) toward red (1.0).
/// The float-to-`u8` cast saturates, so out-of-range weights clamp safely.
fn weight_to_color(weight: f64) -> String {
    let intensity = (weight * 255.0) as u8;
    let fade = 255 - intensity;
    format!("rgb(255, {}, {})", fade, fade)
}
556
/// Escapes the HTML-significant characters so token text renders literally
/// inside the generated tables instead of being parsed as markup.
///
/// `&` must be replaced first; otherwise the `&` introduced by later entity
/// replacements would itself be re-escaped.
///
/// (The previous version replaced each character with itself — the entity
/// strings had been lost — so no escaping happened at all, and the `"`
/// replacement was not even valid syntax.)
fn html_escape(s: &str) -> String {
    s.replace('&', "&amp;")
        .replace('<', "&lt;")
        .replace('>', "&gt;")
        .replace('"', "&quot;")
}
564
#[cfg(test)]
mod tests {
    use super::*;

    /// A 3x3 single-head weight matrix shared by several tests.
    fn demo_weights() -> Vec<Vec<Vec<f64>>> {
        vec![vec![
            vec![0.5, 0.3, 0.2],
            vec![0.1, 0.6, 0.3],
            vec![0.2, 0.3, 0.5],
        ]]
    }

    /// Token labels "A", "B", "C".
    fn abc_tokens() -> Vec<String> {
        vec!["A".to_string(), "B".to_string(), "C".to_string()]
    }

    /// Registers `weights` under "layer.0" as self-attention over `tokens`.
    fn registered(weights: Vec<Vec<Vec<f64>>>, tokens: Vec<String>) -> AttentionVisualizer {
        let mut visualizer = AttentionVisualizer::new();
        visualizer
            .register(
                "layer.0",
                weights,
                tokens.clone(),
                tokens,
                AttentionType::SelfAttention,
            )
            .expect("operation failed in test");
        visualizer
    }

    #[test]
    fn test_attention_visualizer_creation() {
        assert_eq!(AttentionVisualizer::new().num_layers(), 0);
    }

    #[test]
    fn test_register_attention() {
        let visualizer = registered(demo_weights(), abc_tokens());
        assert_eq!(visualizer.num_layers(), 1);
    }

    #[test]
    fn test_create_heatmap() {
        let visualizer = registered(demo_weights(), abc_tokens());
        let heatmap = visualizer.create_heatmap("layer.0", 0).expect("operation failed in test");
        assert_eq!(heatmap.layer_name, "layer.0");
        assert_eq!(heatmap.head, 0);
        assert_eq!(heatmap.weights.len(), 3);
    }

    #[test]
    fn test_analyze_attention() {
        let weights = vec![vec![
            vec![0.7, 0.2, 0.1],
            vec![0.1, 0.8, 0.1],
            vec![0.1, 0.1, 0.8],
        ]];
        let analysis = registered(weights, abc_tokens())
            .analyze("layer.0")
            .expect("operation failed in test");
        assert_eq!(analysis.entropy_per_head.len(), 1);
        assert_eq!(analysis.sparsity_per_head.len(), 1);
        assert!(!analysis.most_attended_tokens.is_empty());
    }

    #[test]
    fn test_export_to_json() {
        let output_path = std::env::temp_dir().join("attention.json");

        let visualizer = registered(vec![vec![vec![1.0]]], vec!["A".to_string()]);
        visualizer
            .export_to_json("layer.0", &output_path)
            .expect("operation failed in test");
        assert!(output_path.exists());

        let _ = std::fs::remove_file(output_path);
    }

    #[test]
    fn test_compute_entropy() {
        let weights = vec![vec![0.5, 0.3, 0.2], vec![1.0, 0.0, 0.0]];
        assert!(compute_entropy(&weights) > 0.0);
    }

    #[test]
    fn test_compute_sparsity() {
        let weights = vec![vec![0.9, 0.05, 0.05], vec![0.01, 0.01, 0.98]];
        let sparsity = compute_sparsity(&weights, 0.1);
        assert!(sparsity > 0.0);
        assert!(sparsity <= 1.0);
    }
}