torsh_vision/explainability.rs
//! Model Explainability and Interpretability Tools
//!
//! This module provides tools for understanding and interpreting vision model predictions:
//! - GradCAM and GradCAM++ (Gradient-weighted Class Activation Mapping)
//! - Saliency Maps (vanilla and SmoothGrad)
//! - Integrated Gradients
//! - Attention Visualization
//! - Feature Visualization
//!
//! These tools help answer questions like:
//! - Which parts of the image influenced the prediction?
//! - What features is the model focusing on?
//! - Are the predictions based on relevant features?

use crate::{Result, VisionError};
#[allow(unused_imports)]
use scirs2_core::ndarray::Array3; // SciRS2 Policy compliance
use std::sync::Arc;
use torsh_core::device::Device;
use torsh_nn::Module;
use torsh_tensor::Tensor;

/// GradCAM (Gradient-weighted Class Activation Mapping)
///
/// GradCAM generates visual explanations for CNN decisions by using gradients
/// flowing into the final convolutional layer to produce a coarse localization map
/// highlighting important regions in the image.
///
/// Reference: Selvaraju et al., "Grad-CAM: Visual Explanations from Deep Networks
/// via Gradient-based Localization", ICCV 2017
pub struct GradCAM {
    target_layer_name: String,
    device: Arc<dyn Device>,
}

impl GradCAM {
    /// Create a new GradCAM explainer
    ///
    /// # Arguments
    /// * `target_layer_name` - Name of the convolutional layer to use for visualization
    ///   (typically the last conv layer before classification)
    /// * `device` - Device to perform computations on
    pub fn new(target_layer_name: String, device: Arc<dyn Device>) -> Self {
        Self {
            target_layer_name,
            device,
        }
    }

    /// Generate GradCAM heatmap for a specific class
    ///
    /// # Arguments
    /// * `model` - The model to explain
    /// * `input` - Input image tensor [C, H, W] or [N, C, H, W]
    /// * `target_class` - Class index to generate explanation for
    ///
    /// # Returns
    /// Heatmap tensor of the same spatial dimensions as the input
    pub fn generate_heatmap(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
    ) -> Result<Tensor> {
        // Ensure input has batch dimension
        let batched_input = if input.ndim() == 3 {
            input.unsqueeze(0)?
        } else {
            input.clone()
        };

        // Forward pass to get activations and output.
        // Note: a complete implementation needs forward/backward hooks on the
        // layer named by `self.target_layer_name` to capture its activations
        // and gradients; those hooks are not wired up yet.
        let output = model.forward(&batched_input)?;

        // Get score for target class
        let target_score = output.narrow(1, target_class as i64, 1)?;

        // Backward pass to get gradients
        target_score.backward()?;

        // A complete implementation would then:
        // 1. Extract activations from the target convolutional layer
        // 2. Extract gradients flowing into that layer
        // 3. Compute importance weights (global average pooling of the gradients)
        // 4. Form a weighted combination of the activation maps
        // 5. Apply ReLU and normalize
        // Steps 3-5 are sketched in `weight_activation_maps` below.

        // Placeholder: return a dummy heatmap until layer hooks are available
        let input_shape = batched_input.shape();
        let height = input_shape.dims()[2] as usize;
        let width = input_shape.dims()[3] as usize;

        let heatmap = self.create_placeholder_heatmap(height, width)?;

        Ok(heatmap)
    }
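
    /// Weighted combination of captured activation maps (steps 3-5 above).
    ///
    /// A minimal sketch, assuming layer hooks have already captured
    /// `activations` and `gradients` for a single sample, both shaped
    /// [C, Hf, Wf]. Only tensor ops already used elsewhere in this module
    /// are relied on; broadcasting is avoided by expressing the channel
    /// weighting as a matrix product.
    #[allow(dead_code)]
    fn weight_activation_maps(&self, activations: &Tensor, gradients: &Tensor) -> Result<Tensor> {
        let dims: Vec<usize> = activations.shape().dims().iter().map(|&x| x as usize).collect();
        let (c, hf, wf) = (dims[0], dims[1], dims[2]);

        // Step 3: per-channel importance weights via global average pooling
        // of the gradients over the spatial dimensions -> shape [C]
        let weights = gradients.mean(Some(&[1, 2]), false)?;

        // Step 4: weighted sum of activation maps across channels, written as
        // [1, C] x [C, Hf*Wf] -> [1, Hf*Wf], then reshaped back to [Hf, Wf]
        let weights_row = weights.reshape(&[1i32, c as i32])?;
        let acts_flat = activations.reshape(&[c as i32, (hf * wf) as i32])?;
        let cam = weights_row
            .matmul(&acts_flat)?
            .reshape(&[hf as i32, wf as i32])?;

        // Step 5: ReLU (clamp below at zero), then min-max normalize to [0, 1];
        // the small epsilon guards against a constant map
        let cam = cam.clamp(0.0, f32::MAX)?;
        let cmin = cam.min()?;
        let cmax = cam.max(None, false)?;
        let normalized = cam.sub(&cmin)?.div(&cmax.sub(&cmin)?.add_scalar(1e-8)?)?;
        Ok(normalized)
    }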

    /// Generate GradCAM++ heatmap (improved version of GradCAM)
    ///
    /// GradCAM++ provides better localization and works better for multiple instances
    /// of the same class in an image. It replaces GradCAM's global-average-pooled
    /// gradient weights with weights derived from second- and third-order
    /// derivatives of the class score.
    pub fn generate_gradcam_plus_plus(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
    ) -> Result<Tensor> {
        // Placeholder: higher-order gradients are not exposed yet, so this
        // currently delegates to standard GradCAM.
        self.generate_heatmap(model, input, target_class)
    }

    /// Overlay heatmap on original image
    ///
    /// # Arguments
    /// * `image` - Original image tensor [C, H, W]
    /// * `heatmap` - Heatmap tensor [H, W]
    /// * `alpha` - Blending factor (0.0 = only image, 1.0 = only heatmap)
    ///
    /// # Returns
    /// Blended visualization [C, H, W]
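    ///
    /// # Example
    ///
    /// A hedged sketch; `image`, `heatmap`, and `device` are assumed to exist:
    ///
    /// ```ignore
    /// let gradcam = GradCAM::new("layer4".to_string(), device);
    /// let overlay = gradcam.overlay_heatmap(&image, &heatmap, 0.4)?;
    /// ```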
    pub fn overlay_heatmap(&self, image: &Tensor, heatmap: &Tensor, alpha: f32) -> Result<Tensor> {
        // Normalize heatmap to [0, 1]; the small epsilon guards against
        // division by zero when the heatmap is constant
        let hmax = heatmap.max(None, false)?;
        let hmin = heatmap.min()?;
        let normalized = heatmap.sub(&hmin)?.div(&hmax.sub(&hmin)?.add_scalar(1e-8)?)?;

        // Convert heatmap to RGB using a colormap (e.g., jet colormap)
        let colored_heatmap = self.apply_colormap(&normalized)?;

        // Blend with original image
        let blended = image
            .mul_scalar(1.0 - alpha)?
            .add(&colored_heatmap.mul_scalar(alpha)?)?;

        Ok(blended)
    }

    /// Apply an approximate jet colormap to a grayscale heatmap in [0, 1]
    fn apply_colormap(&self, heatmap: &Tensor) -> Result<Tensor> {
        // Build RGB channels from simple piecewise-linear ramps; this is a
        // rough approximation rather than an exact jet colormap
        let r = heatmap.mul_scalar(1.5)?.clamp(0.0, 1.0)?.unsqueeze(0)?;
        let g = heatmap
            .mul_scalar(2.0)?
            .sub_scalar(0.5)?
            .clamp(0.0, 1.0)?
            .unsqueeze(0)?;
        let b = heatmap
            .mul_scalar(1.5)?
            .sub_scalar(1.0)?
            .clamp(0.0, 1.0)?
            .unsqueeze(0)?;

        let colored = Tensor::cat(&[&r, &g, &b], 0)?;
        Ok(colored)
    }

    /// Create placeholder heatmap (to be replaced with the actual implementation)
    fn create_placeholder_heatmap(&self, height: usize, width: usize) -> Result<Tensor> {
        use torsh_tensor::creation;

        // Placeholder: an all-zero heatmap of the requested spatial size
        let heatmap: Tensor<f32> = creation::zeros(&[height, width])?;

        Ok(heatmap)
    }
}

/// Saliency Map Generator
///
/// Saliency maps show which input pixels have the greatest influence on the model's
/// prediction by computing the gradient of the output with respect to the input.
pub struct SaliencyMap {
    device: Arc<dyn Device>,
}

impl SaliencyMap {
    /// Create a new saliency map generator
    pub fn new(device: Arc<dyn Device>) -> Self {
        Self { device }
    }

    /// Generate vanilla saliency map
    ///
    /// Computes the gradient of the class score with respect to the input image.
    ///
    /// # Arguments
    /// * `model` - The model to explain
    /// * `input` - Input image tensor [C, H, W] or [N, C, H, W]
    /// * `target_class` - Class index to generate saliency for
    ///
    /// # Returns
    /// Saliency map showing pixel importance
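    ///
    /// # Example
    ///
    /// A hedged sketch; `model`, `image`, and `device` are assumed to exist,
    /// and 281 is an arbitrary class index:
    ///
    /// ```ignore
    /// let saliency = SaliencyMap::new(device);
    /// let map = saliency.generate(&model, &image, 281)?;
    /// ```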
    pub fn generate(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
    ) -> Result<Tensor> {
        // Ensure batch dimension first, then enable gradient tracking on the
        // tensor we will actually read the gradient from
        let batched = if input.ndim() == 3 {
            input.unsqueeze(0)?
        } else {
            input.clone()
        };
        let batched = batched.requires_grad_(true);

        // Forward pass
        let output = model.forward(&batched)?;

        // Get score for target class
        let score = output.narrow(1, target_class as i64, 1)?;

        // Backward pass
        score.backward()?;

        // Get gradient with respect to input
        let grad = batched
            .grad()
            .ok_or_else(|| VisionError::Other(anyhow::anyhow!("No gradient computed")))?;

        // Take absolute value and max across channels
        let abs_grad = grad.abs()?;
        let saliency = abs_grad.max(Some(1), false)?;

        Ok(saliency)
    }

    /// Generate smooth saliency map (SmoothGrad)
    ///
    /// Averages saliency maps computed from multiple noisy versions of the input
    /// to reduce noise and produce smoother visualizations.
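    ///
    /// Following Smilkov et al., "SmoothGrad: removing noise by adding noise"
    /// (2017), the smoothed map is an average over Gaussian perturbations:
    ///   M_smooth(x) = (1/n) * sum_i M(x + eps_i), with eps_i ~ N(0, sigma^2).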
    pub fn generate_smooth(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
        num_samples: usize,
        noise_stddev: f32,
    ) -> Result<Tensor> {
        use torsh_tensor::creation;

        let shape: Vec<usize> = input.shape().dims().iter().map(|&x| x as usize).collect();

        // Accumulate in the saliency map's own shape (the channel dimension is
        // reduced by `generate`), so start from the first sample rather than a
        // zeros tensor shaped like the input
        let mut accumulated: Option<Tensor> = None;

        for _ in 0..num_samples {
            // Add random noise to input
            let noise: Tensor<f32> = creation::randn(&shape)?;
            let noise = noise.mul_scalar(noise_stddev)?;
            let noisy_input = input.add(&noise)?;

            // Generate saliency for noisy input
            let saliency = self.generate(model, &noisy_input, target_class)?;

            // Accumulate
            accumulated = Some(match accumulated {
                Some(acc) => acc.add(&saliency)?,
                None => saliency,
            });
        }

        // Average
        let accumulated = accumulated.ok_or_else(|| {
            VisionError::InvalidArgument("num_samples must be greater than zero".to_string())
        })?;
        let smooth_saliency = accumulated.div_scalar(num_samples as f32)?;

        Ok(smooth_saliency)
    }
}

/// Integrated Gradients
///
/// Integrated Gradients is a method that attributes the prediction of a model to its
/// input features by integrating gradients along a path from a baseline to the input.
///
/// Reference: Sundararajan et al., "Axiomatic Attribution for Deep Networks", ICML 2017
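///
/// For input x, baseline x', and model F, the attribution for feature i is
///   IG_i(x) = (x_i - x'_i) * integral_{a=0..1} dF(x' + a * (x - x')) / dx_i da,
/// which `generate` approximates with a Riemann sum over `num_steps` points.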
pub struct IntegratedGradients {
    baseline_type: BaselineType,
    num_steps: usize,
    device: Arc<dyn Device>,
}

/// Type of baseline to use for integrated gradients
#[derive(Debug, Clone, Copy)]
pub enum BaselineType {
    /// Black image (all zeros)
    Black,
    /// Random noise
    Random,
    /// Blurred version of input
    Blurred,
}

impl IntegratedGradients {
    /// Create a new integrated gradients explainer
    pub fn new(baseline_type: BaselineType, num_steps: usize, device: Arc<dyn Device>) -> Self {
        Self {
            baseline_type,
            num_steps,
            device,
        }
    }

    /// Generate integrated gradients attribution map
    ///
    /// # Arguments
    /// * `model` - The model to explain
    /// * `input` - Input image tensor [C, H, W] or [N, C, H, W]
    /// * `target_class` - Class index to generate attribution for
    ///
    /// # Returns
    /// Attribution map showing feature importance
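    ///
    /// # Example
    ///
    /// A hedged sketch; `model`, `image`, and `device` are assumed to exist:
    ///
    /// ```ignore
    /// let ig = IntegratedGradients::new(BaselineType::Black, 50, device);
    /// let attribution = ig.generate(&model, &image, 281)?;
    /// ```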
    pub fn generate(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
    ) -> Result<Tensor> {
        use torsh_tensor::creation;

        // Create baseline
        let baseline = self.create_baseline(input)?;

        // Accumulate gradients along the straight-line path from the baseline
        // to the input, starting from zeros (not from the baseline itself)
        let shape: Vec<usize> = input.shape().dims().iter().map(|&x| x as usize).collect();
        let mut accumulated_gradients: Tensor<f32> = creation::zeros(&shape)?;

        for step in 0..self.num_steps {
            // Midpoint of each sub-interval, so alpha spans (0, 1) evenly
            let alpha = (step as f32 + 0.5) / (self.num_steps as f32);

            // Interpolated input
            let interpolated = baseline
                .mul_scalar(1.0 - alpha)?
                .add(&input.mul_scalar(alpha)?)?;

            // Enable gradients
            let interp_with_grad = interpolated.clone().requires_grad_(true);

            // Forward pass
            let output = model.forward(&interp_with_grad)?;
            let score = output.narrow(1, target_class as i64, 1)?;

            // Backward pass
            score.backward()?;

            // Accumulate gradients
            let grad = interp_with_grad
                .grad()
                .ok_or_else(|| VisionError::Other(anyhow::anyhow!("No gradient computed")))?;
            accumulated_gradients = accumulated_gradients.add(&grad)?;
        }

        // Average gradients and multiply by (input - baseline)
        let avg_gradients = accumulated_gradients.div_scalar(self.num_steps as f32)?;
        let attribution = input.sub(&baseline)?.mul(&avg_gradients)?;

        Ok(attribution)
    }

    /// Create baseline based on baseline type
    fn create_baseline(&self, input: &Tensor) -> Result<Tensor> {
        use torsh_tensor::creation;

        let shape: Vec<usize> = input.shape().dims().iter().map(|&x| x as usize).collect();

        match self.baseline_type {
            BaselineType::Black => {
                let baseline: Tensor<f32> = creation::zeros(&shape)?;
                Ok(baseline)
            }
            BaselineType::Random => {
                let baseline: Tensor<f32> = creation::randn(&shape)?;
                Ok(baseline.mul_scalar(0.1)?)
            }
            BaselineType::Blurred => {
                // Placeholder: returns the input unchanged until a real
                // Gaussian blur is available. Note that this makes
                // `input - baseline` zero, so the attribution degenerates
                // to an all-zero map for this baseline type.
                Ok(input.clone())
            }
        }
    }
}

/// Attention Visualization
///
/// For models with attention mechanisms (e.g., Vision Transformers),
/// visualizes the attention weights to understand which parts of the image
/// the model is focusing on.
pub struct AttentionVisualizer {
    device: Arc<dyn Device>,
}

impl AttentionVisualizer {
    /// Create a new attention visualizer
    pub fn new(device: Arc<dyn Device>) -> Self {
        Self { device }
    }

    /// Visualize attention weights from a transformer layer
    ///
    /// # Arguments
    /// * `attention_weights` - Attention weight tensor [N, num_heads, seq_len, seq_len],
    ///   where seq_len = 1 + num_patches (a leading CLS token is assumed)
    /// * `patch_size` - Size of image patches
    /// * `image_size` - Original image size (H, W)
    ///
    /// # Returns
    /// Attention map showing which regions are attended to
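    ///
    /// # Example
    ///
    /// A hedged sketch for a ViT-style model with 16x16 patches on a 224x224
    /// input; `attn` and `device` are assumed to exist:
    ///
    /// ```ignore
    /// let visualizer = AttentionVisualizer::new(device);
    /// let map = visualizer.visualize_attention(&attn, 16, (224, 224))?;
    /// ```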
    pub fn visualize_attention(
        &self,
        attention_weights: &Tensor,
        patch_size: usize,
        image_size: (usize, usize),
    ) -> Result<Tensor> {
        // Average attention across heads
        let avg_attention = attention_weights.mean(Some(&[1]), false)?;

        // Extract attention from the CLS token to all tokens (first row),
        // then drop the CLS-to-CLS entry so only patch tokens remain
        let num_patches_h = image_size.0 / patch_size;
        let num_patches_w = image_size.1 / patch_size;
        let num_patches = num_patches_h * num_patches_w;

        let cls_attention = avg_attention.narrow(1, 0, 1)?;
        let cls_to_patches = cls_attention.narrow(2, 1, num_patches as i64)?;

        // Reshape to spatial grid
        let reshaped =
            cls_to_patches.reshape(&[1i32, num_patches_h as i32, num_patches_w as i32])?;

        // Upsample to original image size
        // In production, use proper interpolation
        let upsampled = reshaped.clone();

        Ok(upsampled)
    }

    /// Visualize attention rollout
    ///
    /// Combines attention from multiple layers to understand
    /// the full attention flow through the network.
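    ///
    /// Following Abnar & Zuidema, "Quantifying Attention Flow in Transformers"
    /// (ACL 2020), rollout multiplies per-layer attention matrices together.
    /// The original formulation also mixes in the identity matrix,
    /// 0.5 * (A_l + I) row-normalized, to model residual connections; this
    /// simplified version omits that step.
    ///
    /// ```ignore
    /// // `attn_per_layer` is assumed: one [seq_len, seq_len] tensor per layer.
    /// let rollout = visualizer.attention_rollout(attn_per_layer)?;
    /// ```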
    pub fn attention_rollout(&self, attention_layers: Vec<Tensor>) -> Result<Tensor> {
        if attention_layers.is_empty() {
            return Err(VisionError::InvalidArgument(
                "No attention layers provided".to_string(),
            ));
        }

        // Start from the first layer's attention matrix
        let mut rollout = attention_layers[0].clone();

        // Multiply in the attention matrices of the remaining layers
        for attention in attention_layers.iter().skip(1) {
            rollout = rollout.matmul(attention)?;
        }

        Ok(rollout)
    }
}

/// Feature Visualization
///
/// Generate synthetic images that maximally activate specific neurons or layers,
/// helping understand what features the network has learned.
pub struct FeatureVisualizer {
    learning_rate: f32,
    num_iterations: usize,
    device: Arc<dyn Device>,
}

impl FeatureVisualizer {
    /// Create a new feature visualizer
    pub fn new(learning_rate: f32, num_iterations: usize, device: Arc<dyn Device>) -> Self {
        Self {
            learning_rate,
            num_iterations,
            device,
        }
    }

    /// Generate an image that maximally activates a specific class
    ///
    /// # Arguments
    /// * `model` - The model to visualize
    /// * `target_class` - Class to maximize activation for
    /// * `image_size` - Size of generated image (H, W)
    ///
    /// # Returns
    /// Synthesized image that maximally activates the target class
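    ///
    /// # Example
    ///
    /// A hedged sketch; `model` and `device` are assumed to exist:
    ///
    /// ```ignore
    /// let visualizer = FeatureVisualizer::new(0.1, 100, device);
    /// let image = visualizer.visualize_class(&model, 281, (224, 224))?;
    /// ```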
    pub fn visualize_class(
        &self,
        model: &dyn Module,
        target_class: usize,
        image_size: (usize, usize),
    ) -> Result<Tensor> {
        use torsh_tensor::creation;

        // Initialize a random image near mid-gray
        let mut image: Tensor<f32> = creation::randn(&[1, 3, image_size.0, image_size.1])?;
        image = image.mul_scalar(0.1)?.add_scalar(0.5)?.requires_grad_(true);

        // Gradient-ascent loop on the target class score
        for iteration in 0..self.num_iterations {
            // Forward pass
            let output = model.forward(&image)?;
            let class_score = output.narrow(1, target_class as i64, 1)?;

            // We want to maximize the class score, so minimize its negation
            let loss = class_score.neg()?;

            // Backward pass
            loss.backward()?;

            // Update image
            let grad = image
                .grad()
                .ok_or_else(|| VisionError::Other(anyhow::anyhow!("No gradient computed")))?;
            image = image.sub(&grad.mul_scalar(self.learning_rate)?)?;

            // Apply regularization (keep values in a reasonable range) and
            // re-enable gradient tracking, since the update produced a fresh tensor
            image = image.clamp(-2.0, 2.0)?.requires_grad_(true);

            if iteration % 10 == 0 {
                println!("Iteration {}: loss = {:?}", iteration, loss.item());
            }
        }

        Ok(image)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::CpuDevice;
    use torsh_tensor::creation;

    #[test]
    fn test_gradcam_creation() {
        let device = Arc::new(CpuDevice::new());
        let gradcam = GradCAM::new("layer4".to_string(), device);
        assert_eq!(gradcam.target_layer_name, "layer4");
    }

    #[test]
    fn test_saliency_map_creation() {
        let device = Arc::new(CpuDevice::new());
        let _saliency = SaliencyMap::new(device);
    }

    #[test]
    fn test_integrated_gradients_creation() {
        let device = Arc::new(CpuDevice::new());
        let _ig = IntegratedGradients::new(BaselineType::Black, 50, device);
    }

    #[test]
    fn test_attention_visualizer_creation() {
        let device = Arc::new(CpuDevice::new());
        let _visualizer = AttentionVisualizer::new(device);
    }

    #[test]
    fn test_feature_visualizer_creation() {
        let device = Arc::new(CpuDevice::new());
        let _visualizer = FeatureVisualizer::new(0.1, 100, device);
    }

    #[test]
    fn test_baseline_types() {
        let device: Arc<dyn Device> = Arc::new(CpuDevice::new());
        let ig_black = IntegratedGradients::new(BaselineType::Black, 50, Arc::clone(&device));
        let ig_random = IntegratedGradients::new(BaselineType::Random, 50, Arc::clone(&device));
        let ig_blurred = IntegratedGradients::new(BaselineType::Blurred, 50, Arc::clone(&device));

        let input: Tensor<f32> = creation::ones(&[1, 3, 224, 224]).unwrap();

        assert!(ig_black.create_baseline(&input).is_ok());
        assert!(ig_random.create_baseline(&input).is_ok());
        assert!(ig_blurred.create_baseline(&input).is_ok());
    }
}
575}