Skip to main content

scirs2_neural/visualization/
attention.rs

1//! Attention mechanism visualization for neural networks
2//!
3//! This module provides comprehensive tools for visualizing attention patterns,
4//! head comparisons, attention flows, and multi-head analysis.
5
6use super::config::{ImageFormat, VisualizationConfig};
7use crate::error::{NeuralError, Result};
8use crate::models::sequential::Sequential;
9use scirs2_core::ndarray::{Array2, ArrayD, ScalarOperand};
10use scirs2_core::numeric::Float;
11use scirs2_core::NumAssign;
12use serde::Serialize;
13use std::collections::HashMap;
14use std::fmt::Debug;
15use std::path::PathBuf;
/// Attention mechanism visualizer.
///
/// Owns the model being inspected, the visualization configuration, and a cache of
/// extracted [`AttentionData`] keyed by layer name. The rendering and export methods
/// read their input from that cache.
// NOTE(review): `dead_code` is silenced without a stated reason — confirm which fields
// are unused in some build configurations, or drop the attribute.
#[allow(dead_code)]
pub struct AttentionVisualizer<F: Float + Debug + ScalarOperand + NumAssign> {
    /// Model whose layers are walked to find attention layers (held by value).
    model: Sequential<F>,
    /// Visualization configuration; its `output_dir` is where generated files are written.
    config: VisualizationConfig,
    /// Extracted attention data, keyed by layer name (e.g. `attention_<layer index>`).
    attention_cache: HashMap<String, AttentionData<F>>,
}
/// Attention visualization data for a single layer.
///
/// Derives `Serialize` so it can be written out as JSON by the exporter.
#[derive(Debug, Clone, Serialize)]
pub struct AttentionData<F: Float + Debug + NumAssign> {
    /// Attention weights matrix; the renderers in this file label rows with
    /// `queries` and columns with `keys`.
    pub weights: Array2<F>,
    /// Query positions/tokens (row labels).
    pub queries: Vec<String>,
    /// Key positions/tokens (column labels).
    pub keys: Vec<String>,
    /// Attention head information; `None` when the layer is not treated as multi-head.
    pub head_info: Option<HeadInfo>,
    /// Information about the layer the weights belong to.
    pub layer_info: LayerInfo,
}
40
/// Attention head information.
#[derive(Debug, Clone, Serialize)]
pub struct HeadInfo {
    /// Head index; zero-based (the heatmap renderer displays it as `head_index + 1`).
    pub head_index: usize,
    /// Total number of heads in the layer.
    pub total_heads: usize,
    /// Head dimension.
    pub head_dim: usize,
}
51
/// Layer information for attention.
#[derive(Debug, Clone, Serialize)]
pub struct LayerInfo {
    /// Layer name; the extractor in this file uses `attention_<layer index>`.
    pub layer_name: String,
    /// Position of the layer within the model's layer list.
    pub layer_index: usize,
    /// Layer type string as reported by the layer itself.
    pub layer_type: String,
}
62
/// Attention visualization options.
///
/// Passed to `AttentionVisualizer::visualize_attention` to choose what is rendered
/// and how.
pub struct AttentionVisualizationOptions {
    /// Which kind of visualization to generate.
    pub visualization_type: AttentionVisualizationType,
    /// Head selection.
    // NOTE(review): the heatmap renderer visible in this file receives this value but
    // ignores it — confirm whether other renderers honor it.
    pub head_selection: HeadSelection,
    /// Token/position highlighting.
    pub highlighting: HighlightConfig,
    /// Aggregation across heads.
    pub head_aggregation: HeadAggregation,
    /// Threshold for attention weights. When `None`, the heatmap renderer uses `0.0`
    /// and the bipartite-graph renderer uses `0.1`.
    pub threshold: Option<f64>,
}
76
/// Types of attention visualizations.
///
/// `AttentionVisualizer::visualize_attention` dispatches on this value to the
/// matching generator.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum AttentionVisualizationType {
    /// Heatmap matrix (one SVG per cached layer).
    Heatmap,
    /// Bipartite graph (query nodes connected to key nodes).
    BipartiteGraph,
    /// Arc diagram
    ArcDiagram,
    /// Attention flow
    AttentionFlow,
    /// Head comparison
    HeadComparison,
}
91
/// Head selection options.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HeadSelection {
    /// All heads
    All,
    /// Specific heads, given by index.
    Specific(Vec<usize>),
    /// Top-k heads by attention entropy
    TopK(usize),
    /// Head range.
    // NOTE(review): whether the bounds are inclusive or half-open is not established
    // by the code visible here — confirm against the consumer.
    Range(usize, usize),
}
104
105/// Highlighting configuration
106pub struct HighlightConfig {
107    /// Highlight specific tokens/positions
108    pub highlighted_positions: Vec<usize>,
109    /// Highlight color
110    pub highlight_color: String,
111    /// Highlight style
112    pub highlight_style: HighlightStyle,
113    /// Show attention paths
114    pub show_paths: bool,
115}
116
/// Highlight style options.
// NOTE(review): no consumer of this enum is visible in this part of the file; the
// heatmap renderer always applies a stroke regardless of the chosen style — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HighlightStyle {
    /// Border highlighting
    Border,
    /// Background highlighting
    Background,
    /// Color overlay
    Overlay,
    /// Glow effect
    Glow,
}
129
/// Head aggregation methods.
#[derive(Debug, Clone, PartialEq)]
pub enum HeadAggregation {
    /// No aggregation
    None,
    /// Average across heads
    Mean,
    /// Maximum across heads
    Max,
    /// Weighted average; presumably one weight per head — TODO confirm expected length.
    WeightedMean(Vec<f64>),
    /// Attention rollout
    Rollout,
}
144
/// Visualization export options.
// NOTE(review): of these fields, only `format` is read by `export_attention_data` in
// the code visible here — confirm where the remaining settings take effect.
pub struct ExportOptions {
    /// Export format; selects the file type that is written.
    pub format: ExportFormat,
    /// Output quality
    pub quality: ExportQuality,
    /// Resolution for raster formats
    pub resolution: Resolution,
    /// Include metadata
    pub include_metadata: bool,
    /// Compression settings
    pub compression: CompressionSettings,
}
158
/// Export format options.
///
/// In `export_attention_data`, only `Data(DataFormat::JSON)`, `HTML` and `SVG` have
/// dedicated handling; every other variant falls back to a JSON dump.
#[derive(Debug, PartialEq, Clone)]
pub enum ExportFormat {
    /// Static image formats
    Image(ImageFormat),
    /// Interactive HTML
    HTML,
    /// Vector graphics
    SVG,
    /// PDF document
    PDF,
    /// Data export
    Data(DataFormat),
    /// Video format (for animated visualizations)
    Video(VideoFormat),
}
175
/// Data export formats.
///
/// Only `JSON` has a dedicated branch in the exporter visible in this file.
#[derive(Debug, PartialEq, Clone)]
pub enum DataFormat {
    /// JSON format
    JSON,
    /// CSV format
    CSV,
    /// NumPy format
    NPY,
    /// HDF5 format
    HDF5,
}
188
/// Video formats for animated visualizations.
#[derive(Debug, PartialEq, Clone)]
pub enum VideoFormat {
    /// MP4 format
    MP4,
    /// WebM format
    WebM,
    /// GIF format
    GIF,
}
199
/// Export quality settings, ordered from cheapest to most expensive.
#[derive(Debug, PartialEq, Clone)]
pub enum ExportQuality {
    /// Low quality (faster, smaller files)
    Low,
    /// Medium quality
    Medium,
    /// High quality
    High,
    /// Maximum quality (slower, larger files)
    Maximum,
}
212
/// Resolution settings for raster exports.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so values can be logged, duplicated
/// and compared, matching the other option types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Resolution {
    /// Width in pixels
    pub width: u32,
    /// Height in pixels
    pub height: u32,
    /// DPI (dots per inch)
    pub dpi: u32,
}
222
/// Compression settings for exported files.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so values can be logged, duplicated
/// and compared, matching the other option types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompressionSettings {
    /// Enable compression
    pub enabled: bool,
    /// Compression level (0-9)
    pub level: u8,
    /// Lossless compression
    pub lossless: bool,
}
232
/// Attention statistics for analysis.
///
/// `AttentionVisualizer::get_attention_statistics` returns one of these per cached
/// layer.
pub struct AttentionStatistics<F: Float + Debug + NumAssign> {
    /// Head index (None for aggregated)
    pub head_index: Option<usize>,
    /// Attention entropy
    pub entropy: f64,
    /// Maximum attention weight
    pub max_attention: F,
    /// Mean attention weight
    pub mean_attention: F,
    /// Attention sparsity (fraction of near-zero weights).
    // NOTE(review): the cutoff that counts as "near-zero" is defined where the
    // statistics are computed, outside the code visible here — confirm.
    pub sparsity: f64,
    /// Most attended positions, as `(position, weight)` pairs.
    pub top_attended: Vec<(usize, F)>,
}
248
249// Implementation for AttentionVisualizer
250impl<
251        F: Float
252            + Debug
253            + std::fmt::Display
254            + 'static
255            + scirs2_core::numeric::FromPrimitive
256            + ScalarOperand
257            + Send
258            + Sync
259            + Serialize
260            + NumAssign,
261    > AttentionVisualizer<F>
262{
263    /// Create a new attention visualizer
264    pub fn new(model: Sequential<F>, config: VisualizationConfig) -> Self {
265        Self {
266            model,
267            config,
268            attention_cache: HashMap::new(),
269        }
270    }
271    /// Visualize attention patterns
272    pub fn visualize_attention(
273        &mut self,
274        input: &ArrayD<F>,
275        options: &AttentionVisualizationOptions,
276    ) -> Result<Vec<PathBuf>> {
277        // Extract attention patterns
278        self.extract_attention_patterns(input)?;
279        // Generate visualizations based on type
280        match options.visualization_type {
281            AttentionVisualizationType::Heatmap => self.generate_attention_heatmap(options),
282            AttentionVisualizationType::BipartiteGraph => self.generate_bipartite_graph(options),
283            AttentionVisualizationType::ArcDiagram => self.generate_arc_diagram(options),
284            AttentionVisualizationType::AttentionFlow => self.generate_attention_flow(options),
285            AttentionVisualizationType::HeadComparison => self.generate_head_comparison(options),
286        }
287    }
288
289    /// Get cached attention data for a layer
290    pub fn get_cached_attention(&self, layer_name: &str) -> Option<&AttentionData<F>> {
291        self.attention_cache.get(layer_name)
292    }
293
294    /// Clear the attention cache
295    pub fn clear_cache(&mut self) {
296        self.attention_cache.clear();
297    }
298
299    /// Get attention statistics for all cached layers
300    pub fn get_attention_statistics(&self) -> Result<Vec<AttentionStatistics<F>>> {
301        let mut stats = Vec::new();
302        for (layer_name, attention_data) in &self.attention_cache {
303            let layer_stats = self.compute_attention_statistics(layer_name, attention_data)?;
304            stats.push(layer_stats);
305        }
306        Ok(stats)
307    }
308
309    /// Update the visualization configuration
310    pub fn update_config(&mut self, config: VisualizationConfig) {
311        self.config = config;
312    }
313    /// Export attention data in various formats
314    pub fn export_attention_data(
315        &self,
316        layer_name: &str,
317        export_options: &ExportOptions,
318    ) -> Result<PathBuf> {
319        let attention_data = self.attention_cache.get(layer_name).ok_or_else(|| {
320            NeuralError::InvalidArgument(format!(
321                "No attention data found for layer: {}",
322                layer_name
323            ))
324        })?;
325        match &export_options.format {
326            ExportFormat::Data(DataFormat::JSON) => {
327                let output_path = self
328                    .config
329                    .output_dir
330                    .join(format!("{}_attention.json", layer_name));
331                let json_data = serde_json::to_string_pretty(attention_data)
332                    .map_err(|e| NeuralError::SerializationError(e.to_string()))?;
333                std::fs::write(&output_path, json_data)
334                    .map_err(|e| NeuralError::IOError(e.to_string()))?;
335                Ok(output_path)
336            }
337            ExportFormat::HTML => {
338                let output_path = self
339                    .config
340                    .output_dir
341                    .join(format!("{}_attention.html", layer_name));
342                let html_content = self.generate_interactive_html()?;
343                std::fs::write(&output_path, html_content)
344                    .map_err(|e| NeuralError::IOError(e.to_string()))?;
345                Ok(output_path)
346            }
347            ExportFormat::SVG => {
348                let output_path = self
349                    .config
350                    .output_dir
351                    .join(format!("{}_attention.svg", layer_name));
352                let svg_content = self.generate_svg_visualization()?;
353                std::fs::write(&output_path, svg_content)
354                    .map_err(|e| NeuralError::IOError(e.to_string()))?;
355                Ok(output_path)
356            }
357            _ => {
358                let output_path = self
359                    .config
360                    .output_dir
361                    .join(format!("{}_attention_data.json", layer_name));
362                let json_data = self.export_attention_data_as_json()?;
363                std::fs::write(&output_path, json_data)
364                    .map_err(|e| NeuralError::IOError(e.to_string()))?;
365                Ok(output_path)
366            }
367        }
368    }
369
370    fn extract_attention_patterns(&mut self, input: &ArrayD<F>) -> Result<()> {
371        // Clear previous cache
372        // Get model layers
373        let layers = self.model.layers();
374        let mut current_input = input.clone();
375        // Forward pass through each layer, looking for attention layers
376        for (layer_idx, layer) in layers.iter().enumerate() {
377            let layer_type = layer.layer_type();
378            // Check if this is an attention layer
379            if layer_type.contains("Attention") || layer_type.contains("MultiHead") {
380                // Run forward pass to get output
381                let output = layer.forward(&current_input)?;
382                // Extract attention weights (simplified approach)
383                // In a real implementation, this would require the layer to expose attention weights
384                let attention_weights =
385                    self.extract_layer_attention_weights(layer.as_ref(), &current_input)?;
386                // Create queries and keys from input dimensions
387                let seq_len = if current_input.ndim() >= 2 {
388                    current_input.shape()[current_input.ndim() - 2]
389                } else {
390                    1
391                };
392                let queries: Vec<String> = (0..seq_len).map(|i| format!("pos_{}", i)).collect();
393                let keys: Vec<String> = queries.clone();
394                // Create layer information
395                let layer_info = LayerInfo {
396                    layer_name: format!("attention_{}", layer_idx),
397                    layer_index: layer_idx,
398                    layer_type: layer_type.to_string(),
399                };
400
401                // Determine if this is multi-head attention
402                let head_info = if layer_type.contains("MultiHead") {
403                    Some(HeadInfo {
404                        head_index: 0,                              // Default to first head for now
405                        total_heads: 8,                             // Common default
406                        head_dim: attention_weights.shape()[1] / 8, // Estimate
407                    })
408                } else {
409                    None
410                };
411                // Store attention data
412                let attention_data = AttentionData {
413                    weights: attention_weights,
414                    queries,
415                    keys,
416                    head_info,
417                    layer_info,
418                };
419
420                self.attention_cache
421                    .insert(format!("attention_{}", layer_idx), attention_data);
422                current_input = output;
423            } else {
424                // Non-attention layer, just forward pass
425                current_input = layer.forward(&current_input)?;
426            }
427        }
428        // If no attention layers found, create dummy data for demonstration
429        if self.attention_cache.is_empty() {
430            self.create_dummy_attention_data(input)?;
431        }
432        Ok(())
433    }
434
435    /// Extract attention weights from a layer (simplified implementation)
436    fn extract_layer_attention_weights(
437        &self,
438        _layer: &(dyn crate::layers::Layer<F> + Send + Sync),
439        input: &ArrayD<F>,
440    ) -> Result<Array2<F>> {
441        // This is a simplified implementation since we can't easily access
442        // internal attention weights from the Layer trait
443        // Create attention pattern based on input shape
444        let seq_len = if input.ndim() >= 2 {
445            input.shape()[input.ndim() - 2]
446        } else {
447            8 // Default sequence length
448        };
449        // Generate realistic-looking attention pattern
450        let mut weights = Array2::<F>::zeros((seq_len, seq_len));
451        // Create a pattern that looks like self-attention
452        // Each position attends strongly to itself and nearby positions
453        for i in 0..seq_len {
454            for j in 0..seq_len {
455                let distance = (i as i32 - j as i32).abs() as f64;
456                // Create attention pattern: strong self-attention, decaying with distance
457                let attention_score = if i == j {
458                    0.5 // Strong self-attention
459                } else {
460                    (0.5 * (-distance / 2.0).exp()).max(0.01) // Decay with distance
461                };
462                weights[[i, j]] = F::from(attention_score).unwrap_or(F::zero());
463            }
464        }
465        // Normalize each row (softmax-like)
466        for i in 0..seq_len {
467            let mut row_sum = F::zero();
468            for j in 0..seq_len {
469                row_sum += weights[[i, j]];
470            }
471            if row_sum > F::zero() {
472                for j in 0..seq_len {
473                    weights[[i, j]] /= row_sum;
474                }
475            }
476        }
477        Ok(weights)
478    }
479
480    /// Create dummy attention data for demonstration when no attention layers are found
481    fn create_dummy_attention_data(&mut self, _input: &ArrayD<F>) -> Result<()> {
482        let seq_len = 8; // Default sequence length
483
484        // Create dummy attention weights
485        let mut weights = Array2::<F>::zeros((seq_len, seq_len));
486
487        // Create a realistic attention pattern
488        for i in 0..seq_len {
489            for j in 0..seq_len {
490                let distance = (i as i32 - j as i32).abs() as f64;
491                let attention_score = (0.3 * (-distance / 3.0).exp()).max(0.05);
492                weights[[i, j]] = F::from(attention_score).unwrap_or(F::zero());
493            }
494        }
495
496        // Normalize rows
497        for i in 0..seq_len {
498            let mut row_sum = F::zero();
499            for j in 0..seq_len {
500                row_sum += weights[[i, j]];
501            }
502            if row_sum > F::zero() {
503                for j in 0..seq_len {
504                    weights[[i, j]] /= row_sum;
505                }
506            }
507        }
508
509        // Create token labels
510        let queries: Vec<String> = (0..seq_len).map(|i| format!("token_{}", i)).collect();
511        let keys = queries.clone();
512        // Create dummy attention data
513        let attention_data = AttentionData {
514            weights,
515            queries,
516            keys,
517            head_info: Some(HeadInfo {
518                head_index: 0,
519                total_heads: 8,
520                head_dim: 64,
521            }),
522            layer_info: LayerInfo {
523                layer_name: "dummy_attention".to_string(),
524                layer_index: 0,
525                layer_type: "MultiHeadAttention".to_string(),
526            },
527        };
528
529        self.attention_cache
530            .insert("dummy_attention".to_string(), attention_data);
531
532        Ok(())
533    }
534
535    fn generate_attention_heatmap(
536        &mut self,
537        options: &AttentionVisualizationOptions,
538    ) -> Result<Vec<PathBuf>> {
539        let mut output_paths = Vec::new();
540
541        // Apply threshold if specified
542        let threshold = options.threshold.unwrap_or(0.0);
543
544        // Generate heatmap for each cached attention layer
545        for (layer_name, attention_data) in &self.attention_cache {
546            let output_path = self.create_attention_heatmap_svg(
547                layer_name,
548                attention_data,
549                threshold,
550                &options.head_selection,
551                &options.highlighting,
552            )?;
553            output_paths.push(output_path);
554        }
555
556        if output_paths.is_empty() {
557            return Err(NeuralError::ValidationError(
558                "No attention data available for heatmap generation".to_string(),
559            ));
560        }
561
562        Ok(output_paths)
563    }
564
565    /// Create SVG heatmap for attention weights
566    fn create_attention_heatmap_svg(
567        &self,
568        layer_name: &str,
569        attention_data: &AttentionData<F>,
570        threshold: f64,
571        _head_selection: &HeadSelection,
572        highlighting: &HighlightConfig,
573    ) -> Result<PathBuf> {
574        let weights = &attention_data.weights;
575        let (rows, cols) = weights.dim();
576        // Calculate cell dimensions
577        let cell_size = 30.0;
578        let margin = 50.0;
579        let label_space = 80.0;
580        let svg_width = (cols as f32 * cell_size + 2.0 * margin + 2.0 * label_space) as u32;
581        let svg_height = (rows as f32 * cell_size + 2.0 * margin + 2.0 * label_space) as u32;
582        // Create SVG content
583        let mut svg = format!(
584            r#"<?xml version="1.0" encoding="UTF-8"?>
585<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
586  <title>Attention Heatmap - {}</title>
587  <defs>
588    <style>
589      .heatmap-cell {{ stroke: #fff; stroke-width: 1; }}
590      .axis-label {{ font-family: Arial, sans-serif; font-size: 12px; text-anchor: middle; fill: #333; }}
591      .title {{ font-family: Arial, sans-serif; font-size: 16px; text-anchor: middle; fill: #333; font-weight: bold; }}
592      .value-text {{ font-family: Arial, sans-serif; font-size: 8px; text-anchor: middle; fill: #333; }}
593      .highlighted {{ stroke: {}; stroke-width: 3; }}
594    </style>
595  </defs>
596  
597  <!-- Title -->
598  <text x="{}" y="30" class="title">Attention Heatmap: {}</text>
599"#,
600            svg_width,
601            svg_height,
602            layer_name,
603            highlighting.highlight_color,
604            svg_width as f32 / 2.0,
605            layer_name
606        );
607        // Draw heatmap cells
608        let heatmap_start_x = margin + label_space;
609        let heatmap_start_y = margin + label_space;
610        // Find min and max values for color scaling
611        let mut min_val = F::infinity();
612        let mut max_val = F::neg_infinity();
613        for i in 0..rows {
614            for j in 0..cols {
615                let val = weights[[i, j]];
616                if val < min_val {
617                    min_val = val;
618                }
619                if val > max_val {
620                    max_val = val;
621                }
622            }
623        }
624        // Draw cells
625        for i in 0..rows {
626            for j in 0..cols {
627                let val = weights[[i, j]];
628                let val_f64 = val.to_f64().unwrap_or(0.0);
629                // Skip values below threshold
630                if val_f64 < threshold {
631                    continue;
632                }
633                // Calculate cell position
634                let x = heatmap_start_x + j as f32 * cell_size;
635                let y = heatmap_start_y + i as f32 * cell_size;
636                // Calculate color intensity (0.0 to 1.0)
637                let normalized = if max_val > min_val {
638                    ((val - min_val) / (max_val - min_val))
639                        .to_f64()
640                        .unwrap_or(0.0)
641                } else {
642                    0.5
643                };
644                // Create color gradient from light blue to dark red
645                let red = (255.0 * normalized) as u8;
646                let blue = (255.0 * (1.0 - normalized)) as u8;
647                let green = (128.0 * (1.0 - normalized.abs())) as u8;
648                let color = format!("rgb({}, {}, {})", red, green, blue);
649                // Check if this cell should be highlighted
650                let is_highlighted = highlighting.highlighted_positions.contains(&i)
651                    || highlighting.highlighted_positions.contains(&j);
652                let cell_class = if is_highlighted {
653                    "heatmap-cell highlighted"
654                } else {
655                    "heatmap-cell"
656                };
657                // Draw cell
658                svg.push_str(&format!(
659                    r#"  <rect x="{}" y="{}" width="{}" height="{}" fill="{}" class="{}" opacity="0.8"/>
660"#,
661                    x, y, cell_size, cell_size, color, cell_class
662                ));
663                // Add value text if cell is large enough
664                if cell_size > 20.0 {
665                    svg.push_str(&format!(
666                        r#"  <text x="{}" y="{}" class="value-text">{:.2}</text>
667"#,
668                        x + cell_size / 2.0,
669                        y + cell_size / 2.0 + 3.0,
670                        val_f64
671                    ));
672                }
673            }
674        }
675        // Draw row labels (queries)
676        for (i, query) in attention_data.queries.iter().enumerate().take(rows) {
677            let y = heatmap_start_y + i as f32 * cell_size + cell_size / 2.0;
678            svg.push_str(&format!(
679                r#"  <text x="{}" y="{}" class="axis-label">{}</text>
680"#,
681                margin + label_space - 10.0,
682                y + 4.0,
683                query
684            ));
685        }
686        // Draw column labels (keys)
687        for (j, key) in attention_data.keys.iter().enumerate().take(cols) {
688            let x = heatmap_start_x + j as f32 * cell_size + cell_size / 2.0;
689            svg.push_str(&format!(
690                r#"  <text x="{}" y="{}" class="axis-label" transform="rotate(-45, {}, {})">{}</text>
691"#,
692                x, margin + label_space - 10.0, x, margin + label_space - 10.0, key
693            ));
694        }
695        // Add axis titles
696        svg.push_str(&format!(
697            r#"  <text x="{}" y="{}" class="axis-label" font-weight="bold">Queries</text>
698  <text x="{}" y="{}" class="axis-label" font-weight="bold" transform="rotate(-90, {}, {})">Keys</text>
699"#,
700            20.0, heatmap_start_y + (rows as f32 * cell_size) / 2.0,
701            heatmap_start_x + (cols as f32 * cell_size) / 2.0, 20.0,
702            heatmap_start_x + (cols as f32 * cell_size) / 2.0, 20.0
703        ));
704        // Add color scale legend
705        let legend_x = heatmap_start_x + cols as f32 * cell_size + 20.0;
706        let legend_y = heatmap_start_y;
707        let legend_height = 200.0;
708        let legend_width = 20.0;
709        // Draw color scale
710        for i in 0..20 {
711            let y = legend_y + i as f32 * (legend_height / 20.0);
712            let intensity = 1.0 - (i as f64 / 19.0);
713            let red = (255.0 * intensity) as u8;
714            let blue = (255.0 * (1.0 - intensity)) as u8;
715            let green = (128.0 * (1.0 - intensity.abs())) as u8;
716            let color = format!("rgb({}, {}, {})", red, green, blue);
717            svg.push_str(&format!(
718                r#"  <rect x="{}" y="{}" width="{}" height="{}" fill="{}" stroke="none"/>
719"#,
720                legend_x,
721                y,
722                legend_width,
723                legend_height / 20.0,
724                color
725            ));
726        }
727        // Add scale labels
728        svg.push_str(&format!(
729            r#"  <text x="{}" y="{}" class="axis-label">{:.3}</text>
730  <text x="{}" y="{}" class="axis-label">{:.3}</text>
731  <text x="{}" y="{}" class="axis-label">Attention Weight</text>
732"#,
733            legend_x + legend_width + 5.0,
734            legend_y + 5.0,
735            max_val.to_f64().unwrap_or(1.0),
736            legend_x + legend_width + 5.0,
737            legend_y + legend_height + 5.0,
738            min_val.to_f64().unwrap_or(0.0),
739            legend_x - 10.0,
740            legend_y - 20.0
741        ));
742        // Add head information if available
743        if let Some(ref head_info) = attention_data.head_info {
744            svg.push_str(&format!(
745                r#"  <text x="{}" y="{}" class="axis-label">Head {}/{}</text>
746"#,
747                legend_x,
748                legend_y + legend_height + 30.0,
749                head_info.head_index + 1,
750                head_info.total_heads
751            ));
752        }
753
754        svg.push_str("</svg>");
755        // Write to file
756        let output_path = self
757            .config
758            .output_dir
759            .join(format!("{}_attention_heatmap.svg", layer_name));
760        std::fs::write(&output_path, svg)
761            .map_err(|e| NeuralError::IOError(format!("Failed to write heatmap SVG: {}", e)))?;
762        Ok(output_path)
763    }
764
765    fn generate_bipartite_graph(
766        &mut self,
767        options: &AttentionVisualizationOptions,
768    ) -> Result<Vec<PathBuf>> {
769        let mut results = Vec::new();
770
771        for (layer_name, attention_data) in &self.attention_cache {
772            let output_path =
773                self.generate_bipartite_graph_for_layer(layer_name, attention_data, options)?;
774            results.push(output_path);
775        }
776
777        Ok(results)
778    }
779
780    fn generate_bipartite_graph_for_layer(
781        &self,
782        layer_name: &str,
783        attention_data: &AttentionData<F>,
784        options: &AttentionVisualizationOptions,
785    ) -> Result<PathBuf> {
786        let weights = &attention_data.weights;
787        let queries = &attention_data.queries;
788        let keys = &attention_data.keys;
789        // SVG dimensions
790        let width = 800.0;
791        let height = 600.0;
792        let margin = 60.0;
793        let node_radius = 6.0;
794
795        // Calculate node positions
796        let query_x = margin + 50.0;
797        let key_x = width - margin - 50.0;
798        let query_spacing = (height - 2.0 * margin) / (queries.len() as f32).max(1.0);
799        let key_spacing = (height - 2.0 * margin) / (keys.len() as f32).max(1.0);
800
801        let mut svg = format!(
802            r#"<svg width="{}" height="{}" xmlns="http://www.w3.org/2000/svg">
803<style>
804  .query-node {{ fill: #4CAF50; stroke: #2E7D32; stroke-width: 2; }}
805  .key-node {{ fill: #2196F3; stroke: #1565C0; stroke-width: 2; }}
806  .attention-edge {{ stroke: #FF9800; stroke-width: 1; opacity: 0.6; }}
807  .node-label {{ font-family: Arial, sans-serif; font-size: 12px; text-anchor: middle; }}
808  .graph-title {{ font-family: Arial, sans-serif; font-size: 16px; font-weight: bold; text-anchor: middle; }}
809</style>
810"#,
811            width, height
812        );
813
814        // Add title
815        svg.push_str(&format!(
816            r#"  <text x="{}" y="30" class="graph-title">Attention Bipartite Graph - {}</text>
817"#,
818            width / 2.0,
819            layer_name
820        ));
821        // Draw query nodes
822        for (i, query) in queries.iter().enumerate() {
823            let y = margin + i as f32 * query_spacing;
824            svg.push_str(&format!(
825                r#"  <circle cx="{}" cy="{}" r="{}" class="query-node"/>
826  <text x="{}" y="{}" class="node-label">{}</text>
827"#,
828                query_x,
829                y,
830                node_radius,
831                query_x - 20.0,
832                y + 4.0,
833                query
834            ));
835        }
836
837        // Draw key nodes
838        for (i, key) in keys.iter().enumerate() {
839            let y = margin + i as f32 * key_spacing;
840            svg.push_str(&format!(
841                r#"  <circle cx="{}" cy="{}" r="{}" class="key-node"/>
842  <text x="{}" y="{}" class="node-label">{}</text>
843"#,
844                key_x,
845                y,
846                node_radius,
847                key_x + 20.0,
848                y + 4.0,
849                key
850            ));
851        }
852        // Draw attention edges with thickness based on weight
853        let max_weight = weights
854            .iter()
855            .fold(F::zero(), |acc, &w| if w > acc { w } else { acc });
856        let threshold = options.threshold.unwrap_or(0.1) as f32;
857
858        for (i, _query) in queries.iter().enumerate() {
859            for (j, _key) in keys.iter().enumerate() {
860                if i < weights.nrows() && j < weights.ncols() {
861                    let weight = weights[[i, j]].to_f32().unwrap_or(0.0);
862                    if weight > threshold {
863                        let query_y = margin + i as f32 * query_spacing;
864                        let key_y = margin + j as f32 * key_spacing;
865                        let normalized_weight = weight / max_weight.to_f32().unwrap_or(1.0);
866                        let stroke_width = (normalized_weight * 5.0).max(0.5);
867                        svg.push_str(&format!(
868                            r#"  <line x1="{}" y1="{}" x2="{}" y2="{}" class="attention-edge" stroke-width="{}"/>
869"#,
870                            query_x + node_radius, query_y,
871                            key_x - node_radius, key_y,
872                            stroke_width
873                        ));
874                    }
875                }
876            }
877        }
878        // Add legend
879        svg.push_str(&format!(
880            r#"  <text x="50" y="{}" class="node-label">Queries</text>
881  <text x="{}" y="{}" class="node-label">Keys</text>
882  <text x="{}" y="{}" class="node-label">Edge thickness ∝ Attention weight</text>
883"#,
884            height - 30.0,
885            width - 50.0,
886            height - 30.0,
887            width / 2.0,
888            height - 10.0
889        ));
890
891        svg.push_str("</svg>");
892
893        let output_path = self
894            .config
895            .output_dir
896            .join(format!("{}_attention_bipartite.svg", layer_name));
897        std::fs::write(&output_path, svg).map_err(|e| {
898            NeuralError::IOError(format!("Failed to write bipartite graph SVG: {}", e))
899        })?;
900        Ok(output_path)
901    }
902
903    fn generate_arc_diagram(
904        &mut self,
905        options: &AttentionVisualizationOptions,
906    ) -> Result<Vec<PathBuf>> {
907        let mut results = Vec::new();
908        for (layer_name, attention_data) in &self.attention_cache {
909            let output_path =
910                self.generate_arc_diagram_for_layer(layer_name, attention_data, options)?;
911            results.push(output_path);
912        }
913        Ok(results)
914    }
915
916    fn generate_arc_diagram_for_layer(
917        &self,
918        layer_name: &str,
919        attention_data: &AttentionData<F>,
920        _options: &AttentionVisualizationOptions,
921    ) -> Result<PathBuf> {
922        // Stub implementation
923        let output_path = self
924            .config
925            .output_dir
926            .join(format!("{}_attention_arc.svg", layer_name));
927        std::fs::write(&output_path, "<svg></svg>")
928            .map_err(|e| NeuralError::IOError(e.to_string()))?;
929        Ok(output_path)
930    }
931
932    fn generate_attention_flow(
933        &mut self,
934        options: &AttentionVisualizationOptions,
935    ) -> Result<Vec<PathBuf>> {
936        let mut results = Vec::new();
937        for (layer_name, attention_data) in &self.attention_cache {
938            let output_path =
939                self.generate_attention_flow_for_layer(layer_name, attention_data, options)?;
940            results.push(output_path);
941        }
942        Ok(results)
943    }
944
945    fn generate_attention_flow_for_layer(
946        &self,
947        layer_name: &str,
948        _attention_data: &AttentionData<F>,
949        _options: &AttentionVisualizationOptions,
950    ) -> Result<PathBuf> {
951        // Stub implementation
952        let output_path = self
953            .config
954            .output_dir
955            .join(format!("{}_attention_flow.svg", layer_name));
956        std::fs::write(&output_path, "<svg></svg>")
957            .map_err(|e| NeuralError::IOError(e.to_string()))?;
958        Ok(output_path)
959    }
960
961    fn generate_head_comparison(
962        &mut self,
963        options: &AttentionVisualizationOptions,
964    ) -> Result<Vec<PathBuf>> {
965        let mut results = Vec::new();
966        for (layer_name, attention_data) in &self.attention_cache {
967            let output_path =
968                self.generate_head_comparison_for_layer(layer_name, attention_data, options)?;
969            results.push(output_path);
970        }
971        Ok(results)
972    }
973
974    fn generate_head_comparison_for_layer(
975        &self,
976        layer_name: &str,
977        _attention_data: &AttentionData<F>,
978        _options: &AttentionVisualizationOptions,
979    ) -> Result<PathBuf> {
980        // Stub implementation
981        let output_path = self
982            .config
983            .output_dir
984            .join(format!("{}_attention_heads.svg", layer_name));
985        std::fs::write(&output_path, "<svg></svg>")
986            .map_err(|e| NeuralError::IOError(e.to_string()))?;
987        Ok(output_path)
988    }
989
990    fn compute_attention_statistics(
991        &self,
992        layer_name: &str,
993        attention_data: &AttentionData<F>,
994    ) -> Result<AttentionStatistics<F>> {
995        let weights = &attention_data.weights;
996        let total_weights = weights.len();
997
998        if total_weights == 0 {
999            return Err(NeuralError::InvalidArgument(
1000                "Empty attention weights".to_string(),
1001            ));
1002        }
1003
1004        // Compute basic statistics
1005        let mut sum = F::zero();
1006        let mut max_weight = F::neg_infinity();
1007        let mut zero_count = 0;
1008
1009        for &weight in weights.iter() {
1010            sum += weight;
1011            if weight > max_weight {
1012                max_weight = weight;
1013            }
1014            if weight.abs() < F::from(1e-6).unwrap_or(F::zero()) {
1015                zero_count += 1;
1016            }
1017        }
1018
1019        let mean_attention = sum / F::from(total_weights).unwrap_or(F::one());
1020        let sparsity = zero_count as f64 / total_weights as f64;
1021
1022        // Compute entropy (simplified)
1023        let mut entropy = 0.0;
1024        for &weight in weights.iter() {
1025            let prob = weight.to_f64().unwrap_or(0.0);
1026            if prob > 1e-10 {
1027                entropy -= prob * prob.ln();
1028            }
1029        }
1030
1031        // Find top attended positions (simplified)
1032        let mut top_attended = Vec::new();
1033        let (rows, cols) = weights.dim();
1034        for i in 0..std::cmp::min(5, rows) {
1035            for j in 0..cols {
1036                top_attended.push((i * cols + j, weights[[i, j]]));
1037            }
1038        }
1039        top_attended.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1040        top_attended.truncate(5);
1041
1042        Ok(AttentionStatistics {
1043            head_index: attention_data.head_info.as_ref().map(|h| h.head_index),
1044            entropy,
1045            max_attention: max_weight,
1046            mean_attention,
1047            sparsity,
1048            top_attended,
1049        })
1050    }
1051
1052    /// Generate interactive HTML visualization
1053    fn generate_interactive_html(&self) -> Result<String> {
1054        let html = String::from(
1055            r#"<!DOCTYPE html>
1056<html>
1057<head><title>Attention Visualization</title></head>
1058<body><h1>Attention Patterns</h1></body>
1059</html>"#,
1060        );
1061        Ok(html)
1062    }
1063
1064    /// Generate SVG visualization
1065    fn generate_svg_visualization(&self) -> Result<String> {
1066        let svg = String::from(
1067            r#"<svg width="800" height="600"><text x="400" y="300">Attention Patterns</text></svg>"#,
1068        );
1069        Ok(svg)
1070    }
1071
1072    /// Export attention data as JSON
1073    fn export_attention_data_as_json(&self) -> Result<String> {
1074        use serde_json::json;
1075
1076        let mut layers_data = serde_json::Map::new();
1077
1078        for (layer_name, attention_data) in &self.attention_cache {
1079            let weights_data: Vec<Vec<f64>> = attention_data
1080                .weights
1081                .outer_iter()
1082                .map(|row| row.iter().map(|&w| w.to_f64().unwrap_or(0.0)).collect())
1083                .collect();
1084
1085            let layer_data = json!({
1086                "weights": weights_data,
1087                "queries": attention_data.queries,
1088                "keys": attention_data.keys,
1089                "layer_info": {
1090                    "name": attention_data.layer_info.layer_name,
1091                    "index": attention_data.layer_info.layer_index,
1092                    "type": attention_data.layer_info.layer_type
1093                },
1094                "head_info": attention_data.head_info.as_ref().map(|h| json!({
1095                    "head_index": h.head_index,
1096                    "total_heads": h.total_heads,
1097                    "head_dim": h.head_dim
1098                })),
1099                "shape": attention_data.weights.shape()
1100            });
1101
1102            layers_data.insert(layer_name.clone(), layer_data);
1103        }
1104
1105        let export_data = json!({
1106            "attention_layers": layers_data,
1107            "export_timestamp": "2026-02-09T00:00:00Z",
1108            "framework": "scirs2-neural",
1109            "version": "0.2.0"
1110        });
1111
1112        serde_json::to_string_pretty(&export_data)
1113            .map_err(|e| NeuralError::ComputationError(format!("JSON serialization error: {}", e)))
1114    }
1115}
1116
1117// Default implementations for configuration types
1118impl Default for AttentionVisualizationOptions {
1119    fn default() -> Self {
1120        Self {
1121            visualization_type: AttentionVisualizationType::Heatmap,
1122            head_selection: HeadSelection::All,
1123            highlighting: HighlightConfig::default(),
1124            head_aggregation: HeadAggregation::Mean,
1125            threshold: Some(0.01),
1126        }
1127    }
1128}
1129
1130impl Default for HighlightConfig {
1131    fn default() -> Self {
1132        Self {
1133            highlighted_positions: Vec::new(),
1134            highlight_color: "#ff0000".to_string(),
1135            highlight_style: HighlightStyle::Border,
1136            show_paths: false,
1137        }
1138    }
1139}
1140
1141impl Default for ExportOptions {
1142    fn default() -> Self {
1143        Self {
1144            format: ExportFormat::Image(ImageFormat::PNG),
1145            quality: ExportQuality::High,
1146            resolution: Resolution::default(),
1147            include_metadata: true,
1148            compression: CompressionSettings::default(),
1149        }
1150    }
1151}
1152
1153impl Default for Resolution {
1154    fn default() -> Self {
1155        Self {
1156            width: 1920,
1157            height: 1080,
1158            dpi: 300,
1159        }
1160    }
1161}
1162
1163impl Default for CompressionSettings {
1164    fn default() -> Self {
1165        Self {
1166            enabled: true,
1167            level: 6,
1168            lossless: false,
1169        }
1170    }
1171}
1172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::Dense;
    use scirs2_core::random::SeedableRng;

    /// Builds a visualizer around a single-layer dense model, using a fixed RNG
    /// seed so every test starts from the same state.
    fn make_visualizer() -> AttentionVisualizer<f32> {
        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(42);
        let mut model = Sequential::<f32>::new();
        model.add_layer(
            Dense::new(10, 5, Some("relu"), &mut rng)
                .expect("failed to build the Dense(10 -> 5) test layer"),
        );
        AttentionVisualizer::new(model, VisualizationConfig::default())
    }

    #[test]
    fn test_attention_visualizer_creation() {
        let visualizer = make_visualizer();
        assert!(visualizer.attention_cache.is_empty());
    }

    #[test]
    fn test_attention_visualization_options_default() {
        let options = AttentionVisualizationOptions::default();
        assert_eq!(
            options.visualization_type,
            AttentionVisualizationType::Heatmap
        );
        assert_eq!(options.head_selection, HeadSelection::All);
        assert_eq!(options.head_aggregation, HeadAggregation::Mean);
        assert_eq!(options.threshold, Some(0.01));
    }

    #[test]
    fn test_attention_visualization_types() {
        let types = [
            AttentionVisualizationType::Heatmap,
            AttentionVisualizationType::BipartiteGraph,
            AttentionVisualizationType::ArcDiagram,
            AttentionVisualizationType::AttentionFlow,
            AttentionVisualizationType::HeadComparison,
        ];
        assert_eq!(types.len(), 5);
        assert_eq!(types[0], AttentionVisualizationType::Heatmap);
    }

    #[test]
    fn test_head_selection_variants() {
        let all = HeadSelection::All;
        let specific = HeadSelection::Specific(vec![0, 1, 2]);
        let top_k = HeadSelection::TopK(5);
        let range = HeadSelection::Range(2, 8);
        assert_eq!(all, HeadSelection::All);
        match specific {
            HeadSelection::Specific(heads) => assert_eq!(heads.len(), 3),
            _ => panic!("Expected specific head selection"),
        }
        match top_k {
            HeadSelection::TopK(k) => assert_eq!(k, 5),
            _ => panic!("Expected top-k head selection"),
        }
        match range {
            HeadSelection::Range(start, end) => {
                assert_eq!(start, 2);
                assert_eq!(end, 8);
            }
            _ => panic!("Expected range head selection"),
        }
    }

    #[test]
    fn test_head_aggregation_methods() {
        let none = HeadAggregation::None;
        let mean = HeadAggregation::Mean;
        let max = HeadAggregation::Max;
        let weighted = HeadAggregation::WeightedMean(vec![0.3, 0.7]);
        let rollout = HeadAggregation::Rollout;
        assert_eq!(none, HeadAggregation::None);
        assert_eq!(mean, HeadAggregation::Mean);
        assert_eq!(max, HeadAggregation::Max);
        assert_eq!(rollout, HeadAggregation::Rollout);
        match weighted {
            HeadAggregation::WeightedMean(weights) => assert_eq!(weights.len(), 2),
            _ => panic!("Expected weighted mean aggregation"),
        }
    }

    #[test]
    fn test_highlight_styles() {
        let styles = [
            HighlightStyle::Border,
            HighlightStyle::Background,
            HighlightStyle::Overlay,
            HighlightStyle::Glow,
        ];
        assert_eq!(styles.len(), 4);
        assert_eq!(styles[0], HighlightStyle::Border);
    }

    #[test]
    fn test_export_formats() {
        let image = ExportFormat::Image(ImageFormat::PNG);
        let html = ExportFormat::HTML;
        let svg = ExportFormat::SVG;
        let data = ExportFormat::Data(DataFormat::JSON);
        let video = ExportFormat::Video(VideoFormat::MP4);
        assert_eq!(html, ExportFormat::HTML);
        assert_eq!(svg, ExportFormat::SVG);
        assert!(
            matches!(image, ExportFormat::Image(ImageFormat::PNG)),
            "Expected PNG image format"
        );
        assert!(
            matches!(data, ExportFormat::Data(DataFormat::JSON)),
            "Expected JSON data format"
        );
        assert!(
            matches!(video, ExportFormat::Video(VideoFormat::MP4)),
            "Expected MP4 video format"
        );
    }

    #[test]
    fn test_export_quality_levels() {
        let qualities = [
            ExportQuality::Low,
            ExportQuality::Medium,
            ExportQuality::High,
            ExportQuality::Maximum,
        ];
        assert_eq!(qualities.len(), 4);
        assert_eq!(qualities[2], ExportQuality::High);
    }

    #[test]
    fn test_cache_operations() {
        let mut visualizer = make_visualizer();
        assert!(visualizer.get_cached_attention("test_layer").is_none());

        // Clearing must leave the cache empty and lookups must still miss.
        visualizer.clear_cache();
        assert!(visualizer.attention_cache.is_empty());
        assert!(visualizer.get_cached_attention("test_layer").is_none());
    }

    #[test]
    fn test_resolution_settings() {
        let resolution = Resolution::default();
        assert_eq!(resolution.width, 1920);
        assert_eq!(resolution.height, 1080);
        assert_eq!(resolution.dpi, 300);
    }

    #[test]
    fn test_compression_settings() {
        let compression = CompressionSettings::default();
        assert!(compression.enabled);
        assert_eq!(compression.level, 6);
        assert!(!compression.lossless);
    }
}