1use crate::error::{NeuralError, Result};
7use ndarray::ArrayD;
8use num_traits::Float;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::iter::Sum;
12
/// Selects which feature-visualization technique to run and carries the
/// parameters for that technique.
#[derive(Debug, Clone, PartialEq)]
pub enum VisualizationMethod {
    /// Activation maximization parameters.
    ActivationMaximization {
        /// Name of the layer the visualization refers to.
        target_layer: String,
        /// Optional unit index within the layer; `None` leaves it unspecified.
        target_unit: Option<usize>,
        /// Number of update iterations to perform.
        num_iterations: usize,
        /// Step-size factor applied to each update.
        learning_rate: f64,
    },
    /// DeepDream-style parameters.
    DeepDream {
        /// Name of the layer the visualization refers to.
        target_layer: String,
        /// Number of update iterations to perform.
        num_iterations: usize,
        /// Step-size factor applied to each update.
        learning_rate: f64,
        /// Multiplicative factor applied to the image on every iteration.
        amplify_factor: f64,
    },
    /// Feature inversion parameters.
    FeatureInversion {
        /// Name of the layer the visualization refers to.
        target_layer: String,
        /// Regularization weight (presumably the strength of a prior on the
        /// reconstructed input; confirm against the consumer).
        regularization_weight: f64,
    },
    /// Class activation mapping parameters.
    ClassActivationMapping {
        /// Name of the layer the visualization refers to.
        target_layer: String,
        /// Index of the class of interest.
        target_class: usize,
    },
    /// Network dissection parameters.
    NetworkDissection {
        /// Concept example tensors.
        concept_data: Vec<ArrayD<f32>>,
        /// Concept labels (presumably one per entry of `concept_data`; this
        /// type does not enforce it).
        concept_labels: Vec<String>,
    },
}
60
/// Strategy used by [`AttentionVisualizer::aggregate_attention_heads`] to
/// collapse axis 1 (treated as the head axis) of an attention tensor that has
/// at least four dimensions.
#[derive(Debug, Clone, PartialEq)]
pub enum AttentionAggregation {
    /// Arithmetic mean over axis 1.
    Average,
    /// Element-wise maximum over axis 1.
    Maximum,
    /// Select the single head at the given index along axis 1.
    Head(usize),
    /// Weighted sum over heads; the number of weights must equal the
    /// visualizer's `num_heads`.
    Weighted(Vec<f64>),
}
73
/// Caches per-layer attention tensors and turns them into aggregated
/// attention maps.
#[derive(Debug, Clone)]
pub struct AttentionVisualizer<F: Float + Debug> {
    /// Configured number of attention heads; used to validate head indices and
    /// the length of weighted-aggregation weight vectors.
    pub num_heads: usize,
    /// Configured sequence length; token indices at or beyond it are treated
    /// as out of range by `visualize_attention_flow`.
    pub sequence_length: usize,
    /// How attention heads are combined.
    pub aggregation: AttentionAggregation,
    /// Cached attention tensors keyed by layer name.
    pub attention_cache: HashMap<String, ArrayD<F>>,
    /// Layer names supplied at construction time.
    pub target_layers: Vec<String>,
}
88
/// Output of a feature-visualization run.
#[derive(Debug, Clone)]
pub struct VisualizationResult<F: Float + Debug> {
    /// The method (including its parameters) that produced this result.
    pub method: VisualizationMethod,
    /// The generated visualization tensor.
    pub visualization_data: ArrayD<F>,
    /// Free-form string key/value pairs describing the run.
    pub metadata: HashMap<String, String>,
    /// Quality score attached to the result by the producer.
    pub quality_score: f64,
}
101
/// Per-layer summary produced by [`perform_network_dissection`].
#[derive(Debug, Clone)]
pub struct NetworkDissectionResult {
    /// Name of the analysed layer.
    pub layer_name: String,
    /// Selectivity score per concept label.
    pub concept_selectivity: HashMap<String, f64>,
    /// Total number of activation elements in the analysed layer tensor.
    pub num_units: usize,
    /// Number of responsive units per concept label.
    pub concept_coverage: HashMap<String, usize>,
}
114
115impl<F> AttentionVisualizer<F>
116where
117 F: Float
118 + Debug
119 + 'static
120 + ndarray::ScalarOperand
121 + num_traits::FromPrimitive
122 + Sum
123 + Clone
124 + Copy,
125{
126 pub fn new(
128 num_heads: usize,
129 sequence_length: usize,
130 aggregation: AttentionAggregation,
131 target_layers: Vec<String>,
132 ) -> Self {
133 Self {
134 num_heads,
135 sequence_length,
136 aggregation,
137 attention_cache: HashMap::new(),
138 target_layers,
139 }
140 }
141
142 pub fn cache_attention_weights(&mut self, layer_name: String, attention_weights: ArrayD<F>) {
144 self.attention_cache.insert(layer_name, attention_weights);
145 }
146
147 pub fn visualize_attention(&self, layer_name: &str) -> Result<ArrayD<F>> {
149 let attention_weights = self.attention_cache.get(layer_name).ok_or_else(|| {
150 NeuralError::ComputationError(format!(
151 "No attention weights cached for layer: {}",
152 layer_name
153 ))
154 })?;
155
156 self.aggregate_attention_heads(attention_weights)
157 }
158
159 pub fn aggregate_attention_heads(&self, attention_weights: &ArrayD<F>) -> Result<ArrayD<F>> {
161 match &self.aggregation {
162 AttentionAggregation::Average => {
163 if attention_weights.ndim() >= 4 {
165 Ok(attention_weights.mean_axis(ndarray::Axis(1)).unwrap())
166 } else {
167 Ok(attention_weights.clone())
168 }
169 }
170 AttentionAggregation::Maximum => {
171 if attention_weights.ndim() >= 4 {
173 let max_attention = attention_weights.fold_axis(
174 ndarray::Axis(1),
175 F::neg_infinity(),
176 |&acc, &x| acc.max(x),
177 );
178 Ok(max_attention)
179 } else {
180 Ok(attention_weights.clone())
181 }
182 }
183 AttentionAggregation::Head(head_idx) => {
184 if attention_weights.ndim() >= 4 && *head_idx < self.num_heads {
186 Ok(attention_weights
187 .index_axis(ndarray::Axis(1), *head_idx)
188 .to_owned())
189 } else {
190 Err(NeuralError::InvalidArchitecture(format!(
191 "Invalid head index {} for {} heads",
192 head_idx, self.num_heads
193 )))
194 }
195 }
196 AttentionAggregation::Weighted(weights) => {
197 if weights.len() != self.num_heads {
199 return Err(NeuralError::InvalidArchitecture(
200 "Number of weights must match number of heads".to_string(),
201 ));
202 }
203
204 if attention_weights.ndim() >= 4 {
205 let mut weighted_attention =
206 attention_weights.index_axis(ndarray::Axis(1), 0).to_owned()
207 * F::from(weights[0]).unwrap();
208
209 for (i, &weight) in weights.iter().enumerate().skip(1) {
210 let head_attention =
211 attention_weights.index_axis(ndarray::Axis(1), i).to_owned();
212 weighted_attention =
213 weighted_attention + head_attention * F::from(weight).unwrap();
214 }
215
216 Ok(weighted_attention)
217 } else {
218 Ok(attention_weights.clone())
219 }
220 }
221 }
222 }
223
224 pub fn attention_rollout(&self) -> Result<ArrayD<F>> {
226 if self.attention_cache.is_empty() {
228 return Err(NeuralError::ComputationError(
229 "No attention weights available for rollout".to_string(),
230 ));
231 }
232
233 let first_attention = self.attention_cache.values().next().unwrap();
235 self.aggregate_attention_heads(first_attention)
236 }
237
238 pub fn visualize_attention_flow(
240 &self,
241 layer_name: &str,
242 token_indices: &[usize],
243 ) -> Result<Vec<f64>> {
244 let attention = self.visualize_attention(layer_name)?;
245
246 let mut flow_scores = Vec::new();
247
248 for &token_idx in token_indices {
249 if token_idx < self.sequence_length {
250 let token_attention = attention.index_axis(ndarray::Axis(1), token_idx);
252 let flow_score = token_attention.sum().to_f64().unwrap_or(0.0);
253 flow_scores.push(flow_score);
254 } else {
255 flow_scores.push(0.0);
256 }
257 }
258
259 Ok(flow_scores)
260 }
261}
262
263pub fn generate_feature_visualization<F>(
265 method: &VisualizationMethod,
266 input_shape: &[usize],
267) -> Result<VisualizationResult<F>>
268where
269 F: Float
270 + Debug
271 + 'static
272 + ndarray::ScalarOperand
273 + num_traits::FromPrimitive
274 + Sum
275 + Clone
276 + Copy,
277{
278 match method {
279 VisualizationMethod::ActivationMaximization {
280 target_layer,
281 target_unit,
282 num_iterations,
283 learning_rate,
284 } => {
285 let mut optimized_input = ndarray::Array::zeros(input_shape).into_dyn();
287
288 for _iter in 0..*num_iterations {
289 optimized_input = optimized_input
291 .mapv(|x| x + F::from(*learning_rate * rand::random::<f64>()).unwrap());
292 }
293
294 let mut metadata = HashMap::new();
295 metadata.insert("target_layer".to_string(), target_layer.clone());
296 metadata.insert("iterations".to_string(), num_iterations.to_string());
297 if let Some(unit) = target_unit {
298 metadata.insert("target_unit".to_string(), unit.to_string());
299 }
300
301 Ok(VisualizationResult {
302 method: method.clone(),
303 visualization_data: optimized_input,
304 metadata,
305 quality_score: 0.8,
306 })
307 }
308 VisualizationMethod::DeepDream {
309 target_layer,
310 num_iterations,
311 learning_rate,
312 amplify_factor,
313 } => {
314 let mut dream_input = ndarray::Array::ones(input_shape).into_dyn();
316
317 for _iter in 0..*num_iterations {
318 dream_input = dream_input.mapv(|x| {
320 x * F::from(*amplify_factor).unwrap()
321 + F::from(*learning_rate * rand::random::<f64>()).unwrap()
322 });
323 }
324
325 let mut metadata = HashMap::new();
326 metadata.insert("target_layer".to_string(), target_layer.clone());
327 metadata.insert("iterations".to_string(), num_iterations.to_string());
328 metadata.insert("amplify_factor".to_string(), amplify_factor.to_string());
329
330 Ok(VisualizationResult {
331 method: method.clone(),
332 visualization_data: dream_input,
333 metadata,
334 quality_score: 0.7,
335 })
336 }
337 VisualizationMethod::FeatureInversion {
338 target_layer,
339 regularization_weight,
340 } => {
341 let inverted_input = ndarray::Array::zeros(input_shape).into_dyn();
343
344 let mut metadata = HashMap::new();
345 metadata.insert("target_layer".to_string(), target_layer.clone());
346 metadata.insert(
347 "regularization".to_string(),
348 regularization_weight.to_string(),
349 );
350
351 Ok(VisualizationResult {
352 method: method.clone(),
353 visualization_data: inverted_input,
354 metadata,
355 quality_score: 0.6,
356 })
357 }
358 VisualizationMethod::ClassActivationMapping {
359 target_layer,
360 target_class,
361 } => {
362 let cam_result = ndarray::Array::ones(input_shape).into_dyn();
364
365 let mut metadata = HashMap::new();
366 metadata.insert("target_layer".to_string(), target_layer.clone());
367 metadata.insert("target_class".to_string(), target_class.to_string());
368
369 Ok(VisualizationResult {
370 method: method.clone(),
371 visualization_data: cam_result,
372 metadata,
373 quality_score: 0.85,
374 })
375 }
376 VisualizationMethod::NetworkDissection {
377 concept_data,
378 concept_labels,
379 } => {
380 let dissection_result = ndarray::Array::zeros(input_shape).into_dyn();
382
383 let mut metadata = HashMap::new();
384 metadata.insert("num_concepts".to_string(), concept_labels.len().to_string());
385 metadata.insert("num_examples".to_string(), concept_data.len().to_string());
386
387 Ok(VisualizationResult {
388 method: method.clone(),
389 visualization_data: dissection_result,
390 metadata,
391 quality_score: 0.75,
392 })
393 }
394 }
395}
396
397pub fn perform_network_dissection(
399 layer_name: String,
400 layer_activations: &ArrayD<f32>,
401 concept_data: &[ArrayD<f32>],
402 concept_labels: &[String],
403) -> Result<NetworkDissectionResult> {
404 if concept_data.len() != concept_labels.len() {
405 return Err(NeuralError::InvalidArchitecture(
406 "Number of concept examples must match number of labels".to_string(),
407 ));
408 }
409
410 let mut concept_selectivity = HashMap::new();
411 let mut concept_coverage = HashMap::new();
412
413 for (concept_example, concept_label) in concept_data.iter().zip(concept_labels.iter()) {
415 let selectivity = if layer_activations.len() == concept_example.len() {
417 let correlation = layer_activations
418 .iter()
419 .zip(concept_example.iter())
420 .map(|(&a, &b)| (a as f64) * (b as f64))
421 .sum::<f64>()
422 / layer_activations.len() as f64;
423 correlation.abs()
424 } else {
425 0.0
426 };
427
428 concept_selectivity.insert(concept_label.clone(), selectivity);
429
430 let responsive_units = layer_activations
432 .iter()
433 .zip(concept_example.iter())
434 .filter(|(&a, &b)| (a as f64) * (b as f64) > 0.5)
435 .count();
436
437 concept_coverage.insert(concept_label.clone(), responsive_units);
438 }
439
440 Ok(NetworkDissectionResult {
441 layer_name,
442 concept_selectivity,
443 num_units: layer_activations.len(),
444 concept_coverage,
445 })
446}
447
448pub fn create_attention_heatmap<F>(
450 attention_weights: &ArrayD<F>,
451 token_labels: &[String],
452) -> Result<Vec<Vec<f64>>>
453where
454 F: Float
455 + Debug
456 + 'static
457 + ndarray::ScalarOperand
458 + num_traits::FromPrimitive
459 + Sum
460 + Clone
461 + Copy,
462{
463 if attention_weights.ndim() < 2 {
464 return Err(NeuralError::InvalidArchitecture(
465 "Attention weights must be at least 2D".to_string(),
466 ));
467 }
468
469 let shape = attention_weights.shape();
470 let seq_len = shape[shape.len() - 1];
471
472 if token_labels.len() != seq_len {
473 return Err(NeuralError::InvalidArchitecture(
474 "Number of token labels must match sequence length".to_string(),
475 ));
476 }
477
478 let mut heatmap = Vec::new();
479
480 for i in 0..seq_len {
481 let mut row = Vec::new();
482 for j in 0..seq_len {
483 let weight = if attention_weights.ndim() == 2 {
485 attention_weights[[i, j]].to_f64().unwrap_or(0.0)
486 } else {
487 0.5
490 };
491 row.push(weight);
492 }
493 heatmap.push(row);
494 }
495
496 Ok(heatmap)
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    /// The constructor stores its arguments and starts with an empty cache.
    #[test]
    fn test_attention_visualizer_creation() {
        let visualizer = AttentionVisualizer::<f64>::new(
            8,
            512,
            AttentionAggregation::Average,
            vec!["layer1".to_string(), "layer2".to_string()],
        );

        assert_eq!(visualizer.num_heads, 8);
        assert_eq!(visualizer.sequence_length, 512);
        assert_eq!(visualizer.target_layers.len(), 2);
        assert_eq!(visualizer.aggregation, AttentionAggregation::Average);
        assert!(visualizer.attention_cache.is_empty());
    }

    /// Averaging a (batch, heads, seq, seq) tensor of ones removes the head
    /// axis and keeps every value at one; unknown layers are reported.
    #[test]
    fn test_attention_aggregation() {
        let mut visualizer = AttentionVisualizer::<f64>::new(
            2,
            4,
            AttentionAggregation::Average,
            vec!["test".to_string()],
        );

        let attention = Array::ones((1, 2, 4, 4)).into_dyn();
        visualizer.cache_attention_weights("test".to_string(), attention);

        let aggregated = visualizer
            .visualize_attention("test")
            .expect("cached layer should aggregate");
        assert_eq!(aggregated.shape(), &[1, 4, 4]);
        assert!(aggregated.iter().all(|&v| (v - 1.0).abs() < 1e-12));

        assert!(visualizer.visualize_attention("missing").is_err());
    }

    /// Activation maximization keeps the requested shape, records its
    /// parameters in the metadata and only ever adds non-negative steps.
    #[test]
    fn test_feature_visualization() {
        let method = VisualizationMethod::ActivationMaximization {
            target_layer: "conv1".to_string(),
            target_unit: Some(5),
            num_iterations: 100,
            learning_rate: 0.01,
        };

        let viz_result = generate_feature_visualization::<f64>(&method, &[3, 32, 32])
            .expect("activation maximization should succeed");

        assert_eq!(viz_result.method, method);
        assert_eq!(viz_result.visualization_data.shape(), &[3, 32, 32]);
        assert_eq!(
            viz_result.metadata.get("target_layer").map(String::as_str),
            Some("conv1")
        );
        assert_eq!(
            viz_result.metadata.get("iterations").map(String::as_str),
            Some("100")
        );
        assert_eq!(
            viz_result.metadata.get("target_unit").map(String::as_str),
            Some("5")
        );
        assert!(viz_result
            .visualization_data
            .iter()
            .all(|v| v.is_finite() && *v >= 0.0));
    }

    /// Selectivity is the absolute mean product and coverage counts products
    /// above 0.5; mismatched label counts are rejected.
    #[test]
    fn test_network_dissection() {
        let layer_activations = Array::from_vec(vec![0.5, 0.8, 0.3, 0.9]).into_dyn();
        let concept_data = vec![
            Array::from_vec(vec![0.4, 0.7, 0.2, 0.8]).into_dyn(),
            Array::from_vec(vec![0.6, 0.9, 0.4, 1.0]).into_dyn(),
        ];
        let concept_labels = vec!["dog".to_string(), "car".to_string()];

        let dissection = perform_network_dissection(
            "conv5".to_string(),
            &layer_activations,
            &concept_data,
            &concept_labels,
        )
        .expect("matching concepts and labels should succeed");

        assert_eq!(dissection.layer_name, "conv5");
        assert_eq!(dissection.num_units, 4);
        assert_eq!(dissection.concept_selectivity.len(), 2);
        assert!((dissection.concept_selectivity["dog"] - 0.385).abs() < 1e-5);
        assert!((dissection.concept_selectivity["car"] - 0.51).abs() < 1e-5);
        assert_eq!(dissection.concept_coverage["dog"], 2);
        assert_eq!(dissection.concept_coverage["car"], 2);

        let mismatched = perform_network_dissection(
            "conv5".to_string(),
            &layer_activations,
            &concept_data,
            &concept_labels[..1],
        );
        assert!(mismatched.is_err());
    }

    /// A 2-D attention matrix is copied verbatim into the heatmap and the
    /// label count must match the sequence length.
    #[test]
    fn test_attention_heatmap() {
        let attention = Array::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4])
            .unwrap()
            .into_dyn();
        let tokens = vec!["hello".to_string(), "world".to_string()];

        let heatmap_data = create_attention_heatmap(&attention, &tokens)
            .expect("square attention with matching labels should succeed");

        assert_eq!(heatmap_data.len(), 2);
        assert_eq!(heatmap_data[0].len(), 2);
        assert_eq!(heatmap_data, vec![vec![0.1, 0.2], vec![0.3, 0.4]]);

        assert!(create_attention_heatmap(&attention, &tokens[..1]).is_err());
    }
}