scirs2_neural/interpretation/visualization.rs

//! Attention and feature visualization for neural network interpretation
//!
//! This module provides visualization capabilities for understanding neural network
//! behavior, including attention visualization, feature visualization, and network dissection.

use crate::error::{NeuralError, Result};
use ndarray::ArrayD;
use num_traits::Float;
use std::collections::HashMap;
use std::fmt::Debug;
use std::iter::Sum;

/// Feature visualization method
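///
/// # Examples
///
/// A construction sketch; the layer name is an illustrative placeholder:
///
/// ```ignore
/// let method = VisualizationMethod::DeepDream {
///     target_layer: "mixed4c".to_string(),
///     num_iterations: 50,
///     learning_rate: 0.01,
///     amplify_factor: 1.2,
/// };
/// ```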
#[derive(Debug, Clone, PartialEq)]
pub enum VisualizationMethod {
    /// Activation maximization
    ActivationMaximization {
        /// Target layer name for activation maximization
        target_layer: String,
        /// Specific unit to maximize (None for all)
        target_unit: Option<usize>,
        /// Number of optimization iterations
        num_iterations: usize,
        /// Learning rate for optimization
        learning_rate: f64,
    },
    /// Deep dream
    DeepDream {
        /// Target layer name for deep dream
        target_layer: String,
        /// Number of optimization iterations
        num_iterations: usize,
        /// Learning rate for optimization
        learning_rate: f64,
        /// Factor to amplify activations
        amplify_factor: f64,
    },
    /// Feature inversion
    FeatureInversion {
        /// Target layer name for feature inversion
        target_layer: String,
        /// Weight for regularization term
        regularization_weight: f64,
    },
    /// Class Activation Mapping (CAM)
    ClassActivationMapping {
        /// Target layer for CAM
        target_layer: String,
        /// Target class index
        target_class: usize,
    },
    /// Network dissection for concept visualization
    NetworkDissection {
        /// Concept dataset for analysis
        concept_data: Vec<ArrayD<f32>>,
        /// Concept labels
        concept_labels: Vec<String>,
    },
}

/// Attention aggregation strategy
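///
/// # Examples
///
/// A sketch: `Weighted` carries one weight per head, applied as-is
/// (the weights are not renormalized):
///
/// ```ignore
/// let single_head = AttentionAggregation::Head(3);
/// let blended = AttentionAggregation::Weighted(vec![0.25, 0.75]);
/// ```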
#[derive(Debug, Clone, PartialEq)]
pub enum AttentionAggregation {
    /// Average across all heads
    Average,
    /// Maximum across all heads
    Maximum,
    /// Specific head only
    Head(usize),
    /// Weighted combination of heads
    Weighted(Vec<f64>),
}

/// Attention visualizer for transformer models
#[derive(Debug, Clone)]
pub struct AttentionVisualizer<F: Float + Debug> {
    /// Number of attention heads
    pub num_heads: usize,
    /// Sequence length
    pub sequence_length: usize,
    /// Aggregation strategy
    pub aggregation: AttentionAggregation,
    /// Cached attention weights
    pub attention_cache: HashMap<String, ArrayD<F>>,
    /// Layer names to visualize
    pub target_layers: Vec<String>,
}

/// Visualization result containing processed data
#[derive(Debug, Clone)]
pub struct VisualizationResult<F: Float + Debug> {
    /// Visualization method used
    pub method: VisualizationMethod,
    /// Generated visualization data
    pub visualization_data: ArrayD<F>,
    /// Metadata about the visualization
    pub metadata: HashMap<String, String>,
    /// Confidence or quality score
    pub quality_score: f64,
}

/// Network dissection result
#[derive(Debug, Clone)]
pub struct NetworkDissectionResult {
    /// Layer name analyzed
    pub layer_name: String,
    /// Detected concepts and their selectivity scores
    pub concept_selectivity: HashMap<String, f64>,
    /// Number of units analyzed
    pub num_units: usize,
    /// Coverage of concepts across units
    pub concept_coverage: HashMap<String, usize>,
}

impl<F> AttentionVisualizer<F>
where
    F: Float
        + Debug
        + 'static
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive
        + Sum
        + Clone
        + Copy,
{
    /// Create a new attention visualizer
    pub fn new(
        num_heads: usize,
        sequence_length: usize,
        aggregation: AttentionAggregation,
        target_layers: Vec<String>,
    ) -> Self {
        Self {
            num_heads,
            sequence_length,
            aggregation,
            attention_cache: HashMap::new(),
            target_layers,
        }
    }

    /// Cache attention weights for a layer
    pub fn cache_attention_weights(&mut self, layer_name: String, attention_weights: ArrayD<F>) {
        self.attention_cache.insert(layer_name, attention_weights);
    }

    /// Visualize attention patterns
    pub fn visualize_attention(&self, layer_name: &str) -> Result<ArrayD<F>> {
        let attention_weights = self.attention_cache.get(layer_name).ok_or_else(|| {
            NeuralError::ComputationError(format!(
                "No attention weights cached for layer: {}",
                layer_name
            ))
        })?;

        self.aggregate_attention_heads(attention_weights)
    }

    /// Aggregate attention across multiple heads
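    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming weights of shape `[batch, heads, seq, seq]`
    /// (the import path is inferred from this file's location in the crate):
    ///
    /// ```ignore
    /// use ndarray::Array;
    /// use scirs2_neural::interpretation::visualization::{
    ///     AttentionAggregation, AttentionVisualizer,
    /// };
    ///
    /// let viz = AttentionVisualizer::<f64>::new(2, 4, AttentionAggregation::Average, vec![]);
    /// let attn = Array::ones((1, 2, 4, 4)).into_dyn();
    /// let averaged = viz.aggregate_attention_heads(&attn).unwrap();
    /// assert_eq!(averaged.shape(), &[1, 4, 4]); // head axis reduced away
    /// ```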
    pub fn aggregate_attention_heads(&self, attention_weights: &ArrayD<F>) -> Result<ArrayD<F>> {
        match &self.aggregation {
            AttentionAggregation::Average => {
                // Average across the head dimension (assuming shape: [batch, heads, seq, seq])
                if attention_weights.ndim() >= 4 {
                    attention_weights.mean_axis(ndarray::Axis(1)).ok_or_else(|| {
                        NeuralError::ComputationError(
                            "Cannot average over an empty head axis".to_string(),
                        )
                    })
                } else {
                    Ok(attention_weights.clone())
                }
            }
            AttentionAggregation::Maximum => {
                // Maximum across the head dimension
                if attention_weights.ndim() >= 4 {
                    let max_attention = attention_weights.fold_axis(
                        ndarray::Axis(1),
                        F::neg_infinity(),
                        |&acc, &x| acc.max(x),
                    );
                    Ok(max_attention)
                } else {
                    Ok(attention_weights.clone())
                }
            }
            AttentionAggregation::Head(head_idx) => {
                // Select a specific head
                if attention_weights.ndim() >= 4 && *head_idx < self.num_heads {
                    Ok(attention_weights
                        .index_axis(ndarray::Axis(1), *head_idx)
                        .to_owned())
                } else {
                    Err(NeuralError::InvalidArchitecture(format!(
                        "Invalid head index {} for {} heads",
                        head_idx, self.num_heads
                    )))
                }
            }
            AttentionAggregation::Weighted(weights) => {
                // Weighted combination of heads (weights are applied as-is, not renormalized)
                if weights.len() != self.num_heads {
                    return Err(NeuralError::InvalidArchitecture(
                        "Number of weights must match number of heads".to_string(),
                    ));
                }

                if attention_weights.ndim() >= 4 {
                    let mut weighted_attention =
                        attention_weights.index_axis(ndarray::Axis(1), 0).to_owned()
                            * F::from(weights[0]).unwrap();

                    for (i, &weight) in weights.iter().enumerate().skip(1) {
                        let head_attention =
                            attention_weights.index_axis(ndarray::Axis(1), i).to_owned();
                        weighted_attention =
                            weighted_attention + head_attention * F::from(weight).unwrap();
                    }

                    Ok(weighted_attention)
                } else {
                    Ok(attention_weights.clone())
                }
            }
        }
    }

    /// Generate attention rollout visualization
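    ///
    /// In a full implementation (Abnar & Zuidema, 2020), the rollout at layer
    /// `l` is the running matrix product of head-averaged attention maps,
    /// `R_l = A_l · R_{l-1}` with `R_1 = A_1`, often mixing in the identity
    /// matrix to account for residual connections. The version below is a
    /// simplified stand-in that aggregates a single cached layer.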
    pub fn attention_rollout(&self) -> Result<ArrayD<F>> {
        if self.attention_cache.is_empty() {
            return Err(NeuralError::ComputationError(
                "No attention weights available for rollout".to_string(),
            ));
        }

        // Simplified rollout: aggregate a single cached layer. Note that
        // HashMap iteration order is unspecified, so this is an arbitrary
        // cached layer, not necessarily the first one cached.
        let first_attention = self.attention_cache.values().next().unwrap();
        self.aggregate_attention_heads(first_attention)
    }

    /// Visualize attention flow between tokens
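    ///
    /// # Examples
    ///
    /// A usage sketch (import path assumed); each flow score is the sum of
    /// the token's row in the aggregated attention map:
    ///
    /// ```ignore
    /// use ndarray::Array;
    /// use scirs2_neural::interpretation::visualization::{
    ///     AttentionAggregation, AttentionVisualizer,
    /// };
    ///
    /// let mut viz = AttentionVisualizer::<f64>::new(
    ///     2, 4, AttentionAggregation::Average, vec!["attn0".to_string()]);
    /// viz.cache_attention_weights("attn0".to_string(), Array::ones((1, 2, 4, 4)).into_dyn());
    /// let flow = viz.visualize_attention_flow("attn0", &[0, 2]).unwrap();
    /// assert_eq!(flow.len(), 2);
    /// ```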
    pub fn visualize_attention_flow(
        &self,
        layer_name: &str,
        token_indices: &[usize],
    ) -> Result<Vec<f64>> {
        let attention = self.visualize_attention(layer_name)?;

        let mut flow_scores = Vec::new();

        for &token_idx in token_indices {
            if token_idx < self.sequence_length {
                // Sum the token's outgoing attention: row `token_idx` of the
                // aggregated [batch, seq, seq] map
                let token_attention = attention.index_axis(ndarray::Axis(1), token_idx);
                let flow_score = token_attention.sum().to_f64().unwrap_or(0.0);
                flow_scores.push(flow_score);
            } else {
                flow_scores.push(0.0);
            }
        }

        Ok(flow_scores)
    }
}

/// Generate feature visualization using specified method
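///
/// # Examples
///
/// A sketch using activation maximization (import path assumed). Note that
/// the current implementation is a simplified placeholder that perturbs the
/// input randomly rather than following real gradients:
///
/// ```ignore
/// use scirs2_neural::interpretation::visualization::{
///     generate_feature_visualization, VisualizationMethod,
/// };
///
/// let method = VisualizationMethod::ActivationMaximization {
///     target_layer: "conv1".to_string(),
///     target_unit: None,
///     num_iterations: 10,
///     learning_rate: 0.01,
/// };
/// let result = generate_feature_visualization::<f64>(&method, &[3, 32, 32]).unwrap();
/// assert_eq!(result.visualization_data.shape(), &[3, 32, 32]);
/// ```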
pub fn generate_feature_visualization<F>(
    method: &VisualizationMethod,
    input_shape: &[usize],
) -> Result<VisualizationResult<F>>
where
    F: Float
        + Debug
        + 'static
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive
        + Sum
        + Clone
        + Copy,
{
    match method {
        VisualizationMethod::ActivationMaximization {
            target_layer,
            target_unit,
            num_iterations,
            learning_rate,
        } => {
            // Simplified activation maximization: real gradient ascent is
            // mocked here with random perturbations of a zero-initialized input
            let mut optimized_input = ndarray::Array::zeros(input_shape).into_dyn();

            for _iter in 0..*num_iterations {
                optimized_input = optimized_input
                    .mapv(|x| x + F::from(*learning_rate * rand::random::<f64>()).unwrap());
            }

            let mut metadata = HashMap::new();
            metadata.insert("target_layer".to_string(), target_layer.clone());
            metadata.insert("iterations".to_string(), num_iterations.to_string());
            if let Some(unit) = target_unit {
                metadata.insert("target_unit".to_string(), unit.to_string());
            }

            Ok(VisualizationResult {
                method: method.clone(),
                visualization_data: optimized_input,
                metadata,
                // Fixed placeholder score; a real implementation would
                // measure optimization convergence
                quality_score: 0.8,
            })
        }
        VisualizationMethod::DeepDream {
            target_layer,
            num_iterations,
            learning_rate,
            amplify_factor,
        } => {
            // Simplified deep dream implementation
            let mut dream_input = ndarray::Array::ones(input_shape).into_dyn();

            for _iter in 0..*num_iterations {
                // Amplify activations (simplified)
                dream_input = dream_input.mapv(|x| {
                    x * F::from(*amplify_factor).unwrap()
                        + F::from(*learning_rate * rand::random::<f64>()).unwrap()
                });
            }

            let mut metadata = HashMap::new();
            metadata.insert("target_layer".to_string(), target_layer.clone());
            metadata.insert("iterations".to_string(), num_iterations.to_string());
            metadata.insert("amplify_factor".to_string(), amplify_factor.to_string());

            Ok(VisualizationResult {
                method: method.clone(),
                visualization_data: dream_input,
                metadata,
                quality_score: 0.7,
            })
        }
        VisualizationMethod::FeatureInversion {
            target_layer,
            regularization_weight,
        } => {
            // Simplified feature inversion
            let inverted_input = ndarray::Array::zeros(input_shape).into_dyn();

            let mut metadata = HashMap::new();
            metadata.insert("target_layer".to_string(), target_layer.clone());
            metadata.insert(
                "regularization".to_string(),
                regularization_weight.to_string(),
            );

            Ok(VisualizationResult {
                method: method.clone(),
                visualization_data: inverted_input,
                metadata,
                quality_score: 0.6,
            })
        }
        VisualizationMethod::ClassActivationMapping {
            target_layer,
            target_class,
        } => {
            // Simplified CAM
            let cam_result = ndarray::Array::ones(input_shape).into_dyn();

            let mut metadata = HashMap::new();
            metadata.insert("target_layer".to_string(), target_layer.clone());
            metadata.insert("target_class".to_string(), target_class.to_string());

            Ok(VisualizationResult {
                method: method.clone(),
                visualization_data: cam_result,
                metadata,
                quality_score: 0.85,
            })
        }
        VisualizationMethod::NetworkDissection {
            concept_data,
            concept_labels,
        } => {
            // Simplified network dissection
            let dissection_result = ndarray::Array::zeros(input_shape).into_dyn();

            let mut metadata = HashMap::new();
            metadata.insert("num_concepts".to_string(), concept_labels.len().to_string());
            metadata.insert("num_examples".to_string(), concept_data.len().to_string());

            Ok(VisualizationResult {
                method: method.clone(),
                visualization_data: dissection_result,
                metadata,
                quality_score: 0.75,
            })
        }
    }
}

/// Perform network dissection analysis
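///
/// # Examples
///
/// A sketch with toy activations (import path assumed). Selectivity here is
/// a mean elementwise product, a simplified stand-in for the IoU-based
/// scores of the original network dissection method:
///
/// ```ignore
/// use ndarray::Array;
/// use scirs2_neural::interpretation::visualization::perform_network_dissection;
///
/// let acts = Array::from_vec(vec![0.5_f32, 0.8, 0.3, 0.9]).into_dyn();
/// let concepts = vec![Array::from_vec(vec![0.4_f32, 0.7, 0.2, 0.8]).into_dyn()];
/// let labels = vec!["dog".to_string()];
/// let result =
///     perform_network_dissection("conv5".to_string(), &acts, &concepts, &labels).unwrap();
/// assert_eq!(result.concept_selectivity.len(), 1);
/// ```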
pub fn perform_network_dissection(
    layer_name: String,
    layer_activations: &ArrayD<f32>,
    concept_data: &[ArrayD<f32>],
    concept_labels: &[String],
) -> Result<NetworkDissectionResult> {
    if concept_data.len() != concept_labels.len() {
        return Err(NeuralError::InvalidArchitecture(
            "Number of concept examples must match number of labels".to_string(),
        ));
    }

    let mut concept_selectivity = HashMap::new();
    let mut concept_coverage = HashMap::new();

    // Simplified network dissection
    for (concept_example, concept_label) in concept_data.iter().zip(concept_labels.iter()) {
        // Compute a selectivity score: the mean elementwise product, a
        // simplified stand-in for a true correlation
        let selectivity = if layer_activations.len() == concept_example.len() {
            let correlation = layer_activations
                .iter()
                .zip(concept_example.iter())
                .map(|(&a, &b)| (a as f64) * (b as f64))
                .sum::<f64>()
                / layer_activations.len() as f64;
            correlation.abs()
        } else {
            0.0
        };

        concept_selectivity.insert(concept_label.clone(), selectivity);

        // Count units that respond to this concept (0.5 is an arbitrary
        // response threshold)
        let responsive_units = layer_activations
            .iter()
            .zip(concept_example.iter())
            .filter(|(&a, &b)| (a as f64) * (b as f64) > 0.5)
            .count();

        concept_coverage.insert(concept_label.clone(), responsive_units);
    }

    Ok(NetworkDissectionResult {
        layer_name,
        concept_selectivity,
        num_units: layer_activations.len(),
        concept_coverage,
    })
}

/// Create attention heatmap for visualization
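///
/// # Examples
///
/// A minimal sketch (import path assumed):
///
/// ```ignore
/// use ndarray::Array;
/// use scirs2_neural::interpretation::visualization::create_attention_heatmap;
///
/// let attn = Array::from_shape_vec((2, 2), vec![0.1_f64, 0.2, 0.3, 0.4])
///     .unwrap()
///     .into_dyn();
/// let tokens = vec!["hello".to_string(), "world".to_string()];
/// let heatmap = create_attention_heatmap(&attn, &tokens).unwrap();
/// assert_eq!(heatmap[0][1], 0.2);
/// ```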
pub fn create_attention_heatmap<F>(
    attention_weights: &ArrayD<F>,
    token_labels: &[String],
) -> Result<Vec<Vec<f64>>>
where
    F: Float
        + Debug
        + 'static
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive
        + Sum
        + Clone
        + Copy,
{
    if attention_weights.ndim() < 2 {
        return Err(NeuralError::InvalidArchitecture(
            "Attention weights must be at least 2D".to_string(),
        ));
    }

    let shape = attention_weights.shape();
    let seq_len = shape[shape.len() - 1];

    if token_labels.len() != seq_len {
        return Err(NeuralError::InvalidArchitecture(
            "Number of token labels must match sequence length".to_string(),
        ));
    }

    let mut heatmap = Vec::new();

    for i in 0..seq_len {
        let mut row = Vec::new();
        for j in 0..seq_len {
            // Get attention weight for position (i, j)
            let weight = if attention_weights.ndim() == 2 {
                attention_weights[[i, j]].to_f64().unwrap_or(0.0)
            } else {
                // For higher-dimensional attention (e.g. [batch, heads, seq, seq]),
                // read position (i, j) from the first slice along every leading
                // axis (batch 0, head 0, ...)
                let mut index = vec![0; attention_weights.ndim()];
                let ndim = index.len();
                index[ndim - 2] = i;
                index[ndim - 1] = j;
                attention_weights[ndarray::IxDyn(&index)]
                    .to_f64()
                    .unwrap_or(0.0)
            };
            row.push(weight);
        }
        heatmap.push(row);
    }

    Ok(heatmap)
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    #[test]
    fn test_attention_visualizer_creation() {
        let visualizer = AttentionVisualizer::<f64>::new(
            8,
            512,
            AttentionAggregation::Average,
            vec!["layer1".to_string(), "layer2".to_string()],
        );

        assert_eq!(visualizer.num_heads, 8);
        assert_eq!(visualizer.sequence_length, 512);
        assert_eq!(visualizer.target_layers.len(), 2);
    }

    #[test]
    fn test_attention_aggregation() {
        let mut visualizer = AttentionVisualizer::<f64>::new(
            2,
            4,
            AttentionAggregation::Average,
            vec!["test".to_string()],
        );

        // Create mock attention weights: [batch=1, heads=2, seq=4, seq=4]
        let attention = Array::ones((1, 2, 4, 4)).into_dyn();
        visualizer.cache_attention_weights("test".to_string(), attention);

        let aggregated = visualizer.visualize_attention("test");
        assert!(aggregated.is_ok());
    }

    #[test]
    fn test_feature_visualization() {
        let method = VisualizationMethod::ActivationMaximization {
            target_layer: "conv1".to_string(),
            target_unit: Some(5),
            num_iterations: 100,
            learning_rate: 0.01,
        };

        let result = generate_feature_visualization::<f64>(&method, &[3, 32, 32]);
        assert!(result.is_ok());

        let viz_result = result.unwrap();
        assert_eq!(viz_result.visualization_data.shape(), &[3, 32, 32]);
        assert!(viz_result.metadata.contains_key("target_layer"));
    }

    #[test]
    fn test_network_dissection() {
        let layer_activations = Array::from_vec(vec![0.5, 0.8, 0.3, 0.9]).into_dyn();
        let concept_data = vec![
            Array::from_vec(vec![0.4, 0.7, 0.2, 0.8]).into_dyn(),
            Array::from_vec(vec![0.6, 0.9, 0.4, 1.0]).into_dyn(),
        ];
        let concept_labels = vec!["dog".to_string(), "car".to_string()];

        let result = perform_network_dissection(
            "conv5".to_string(),
            &layer_activations,
            &concept_data,
            &concept_labels,
        );

        assert!(result.is_ok());
        let dissection = result.unwrap();
        assert_eq!(dissection.layer_name, "conv5");
        assert_eq!(dissection.concept_selectivity.len(), 2);
    }

    #[test]
    fn test_attention_heatmap() {
        let attention = Array::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn();
        let tokens = vec!["hello".to_string(), "world".to_string()];

        let heatmap = create_attention_heatmap(&attention, &tokens);
        assert!(heatmap.is_ok());

        let heatmap_data = heatmap.unwrap();
        assert_eq!(heatmap_data.len(), 2);
        assert_eq!(heatmap_data[0].len(), 2);
    }
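
    // The tests below are additional sketches covering behavior not exercised
    // above; expected values follow directly from the implementations in this
    // file.

    #[test]
    fn test_weighted_aggregation() {
        let visualizer = AttentionVisualizer::<f64>::new(
            2,
            4,
            AttentionAggregation::Weighted(vec![0.25, 0.75]),
            vec![],
        );

        // Two heads of ones combined with weights 0.25 and 0.75 sum to 1.0
        let attention = Array::ones((1, 2, 4, 4)).into_dyn();
        let combined = visualizer.aggregate_attention_heads(&attention).unwrap();
        assert_eq!(combined.shape(), &[1, 4, 4]);
        assert!((combined[[0, 0, 0]] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_attention_flow_out_of_range() {
        let mut visualizer = AttentionVisualizer::<f64>::new(
            2,
            4,
            AttentionAggregation::Average,
            vec!["test".to_string()],
        );
        visualizer
            .cache_attention_weights("test".to_string(), Array::ones((1, 2, 4, 4)).into_dyn());

        // In-range tokens sum their attention row; out-of-range tokens get 0.0
        let flow = visualizer
            .visualize_attention_flow("test", &[0, 99])
            .unwrap();
        assert_eq!(flow.len(), 2);
        assert!((flow[0] - 4.0).abs() < 1e-10);
        assert_eq!(flow[1], 0.0);
    }

    #[test]
    fn test_attention_heatmap_batched() {
        // 4D input [batch=1, heads=1, seq=2, seq=2]; the heatmap reads the
        // first slice along every leading axis
        let attention = Array::from_shape_vec((1, 1, 2, 2), vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn();
        let tokens = vec!["a".to_string(), "b".to_string()];

        let heatmap = create_attention_heatmap(&attention, &tokens).unwrap();
        assert!((heatmap[0][1] - 0.2).abs() < 1e-10);
    }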
}