//! Model Explainability and Interpretability Tools
//!
//! This module provides tools for understanding and interpreting vision model predictions:
//! - GradCAM (Gradient-weighted Class Activation Mapping)
//! - Saliency Maps
//! - Guided Backpropagation
//! - Integrated Gradients
//! - Attention Visualization
//!
//! These tools help answer questions like:
//! - Which parts of the image influenced the prediction?
//! - What features is the model focusing on?
//! - Are the predictions based on relevant features?
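//!
//! # Example (sketch)
//!
//! A minimal usage sketch, assuming `model` implements [`torsh_nn::Module`] and
//! `image` is a `[3, H, W]` tensor; the layer name is illustrative:
//!
//! ```ignore
//! use std::sync::Arc;
//! use torsh_core::device::CpuDevice;
//!
//! let device = Arc::new(CpuDevice::new());
//! let gradcam = GradCAM::new("layer4".to_string(), device);
//! let heatmap = gradcam.generate_heatmap(&model, &image, target_class)?;
//! let overlay = gradcam.overlay_heatmap(&image, &heatmap, 0.5)?;
//! ```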

use crate::{Result, VisionError};
use scirs2_core::ndarray::Array3; // SciRS2 Policy compliance
use std::sync::Arc;
use torsh_core::device::Device;
use torsh_nn::Module;
use torsh_tensor::Tensor;

/// GradCAM (Gradient-weighted Class Activation Mapping)
///
/// GradCAM generates visual explanations for CNN decisions by using gradients
/// flowing into the final convolutional layer to produce a coarse localization map
/// highlighting important regions in the image.
///
/// Reference: Selvaraju et al., "Grad-CAM: Visual Explanations from Deep Networks
/// via Gradient-based Localization", ICCV 2017
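///
/// The heatmap is a weighted combination of the target layer's activation maps
/// `A^k`, with weights obtained by global average pooling of the gradients:
/// `α_k = (1/Z) Σ_{i,j} ∂y_c/∂A^k_{ij}`, then `L = ReLU(Σ_k α_k · A^k)`.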
pub struct GradCAM {
    target_layer_name: String,
    device: Arc<dyn Device>,
}

impl GradCAM {
    /// Create a new GradCAM explainer
    ///
    /// # Arguments
    /// * `target_layer_name` - Name of the convolutional layer to use for visualization
    ///   (typically the last conv layer before classification)
    /// * `device` - Device to perform computations on
    pub fn new(target_layer_name: String, device: Arc<dyn Device>) -> Self {
        Self {
            target_layer_name,
            device,
        }
    }

    /// Generate GradCAM heatmap for a specific class
    ///
    /// # Arguments
    /// * `model` - The model to explain
    /// * `input` - Input image tensor [C, H, W] or [N, C, H, W]
    /// * `target_class` - Class index to generate explanation for
    ///
    /// # Returns
    /// Heatmap tensor of the same spatial dimensions as the input
    pub fn generate_heatmap(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
    ) -> Result<Tensor> {
        // Ensure input has batch dimension
        let batched_input = if input.ndim() == 3 {
            input.unsqueeze(0)?
        } else {
            input.clone()
        };

        // Forward pass to get activations and output
        // Note: In a complete implementation, we would need hooks to capture
        // intermediate activations from the target layer
        let output = model.forward(&batched_input)?;

        // Get score for target class
        let target_score = output.narrow(1, target_class as i64, 1)?;

        // Backward pass to get gradients
        target_score.backward()?;

        // In a complete implementation:
        // 1. Extract activations from target convolutional layer
        // 2. Extract gradients flowing into the target layer
        // 3. Compute importance weights (global average pooling of gradients)
        // 4. Weighted combination of activation maps
        // 5. Apply ReLU and normalize
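        //
        // A hypothetical sketch of steps 3-5, assuming `activations` and `grads`
        // ([N, K, h, w] tensors for the target layer) were captured via hooks;
        // the method names mirror those used elsewhere in this module:
        //   let weights = grads.mean(Some(&[2, 3]), true)?;  // GAP over spatial dims
        //   let cam = activations.mul(&weights)?;            // weight each channel map
        //   // then sum over the channel dim, apply ReLU, and normalize to [0, 1]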

        // Placeholder: Create a dummy heatmap for demonstration
        // In production, this would be replaced with actual GradCAM computation
        let input_shape = batched_input.shape();
        let height = input_shape.dims()[2] as usize;
        let width = input_shape.dims()[3] as usize;

        let heatmap = self.create_placeholder_heatmap(height, width)?;

        Ok(heatmap)
    }

    /// Generate GradCAM++ heatmap (improved version of GradCAM)
    ///
    /// GradCAM++ provides better localization and works better for multiple instances
    /// of the same class in an image.
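    ///
    /// GradCAM++ replaces the uniform gradient pooling with pixel-wise weights
    /// derived from second- and third-order derivatives of the class score with
    /// respect to the activation maps (Chattopadhay et al., WACV 2018).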
    pub fn generate_gradcam_plus_plus(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
    ) -> Result<Tensor> {
        // Placeholder: currently delegates to standard GradCAM; a full
        // GradCAM++ implementation would use second- and third-order
        // derivatives for better weighting of the activation maps
        self.generate_heatmap(model, input, target_class)
    }

    /// Overlay heatmap on original image
    ///
    /// # Arguments
    /// * `image` - Original image tensor [C, H, W]
    /// * `heatmap` - Heatmap tensor [H, W]
    /// * `alpha` - Blending factor (0.0 = only image, 1.0 = only heatmap)
    ///
    /// # Returns
    /// Blended visualization [C, H, W]
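    ///
    /// The blend is computed as `(1 - alpha) * image + alpha * colormap(heatmap)`.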
    pub fn overlay_heatmap(&self, image: &Tensor, heatmap: &Tensor, alpha: f32) -> Result<Tensor> {
        // Normalize heatmap to [0, 1]; the small epsilon guards against
        // division by zero on a constant heatmap
        let hmax = heatmap.max(None, false)?;
        let hmin = heatmap.min()?;
        let range = hmax.sub(&hmin)?.add_scalar(1e-8)?;
        let normalized = heatmap.sub(&hmin)?.div(&range)?;

        // Convert heatmap to RGB using a colormap
        let colored_heatmap = self.apply_colormap(&normalized)?;

        // Blend with original image
        let blended = image
            .mul_scalar(1.0 - alpha)?
            .add(&colored_heatmap.mul_scalar(alpha)?)?;

        Ok(blended)
    }

    /// Apply a colormap to a grayscale heatmap
    fn apply_colormap(&self, heatmap: &Tensor) -> Result<Tensor> {
        // Create RGB channels from the grayscale heatmap. This simplified ramp
        // (red saturates first, then green, then blue) approximates a
        // "hot"-style colormap rather than a true jet colormap
        let r = heatmap.mul_scalar(1.5)?.clamp(0.0, 1.0)?.unsqueeze(0)?;
        let g = heatmap
            .mul_scalar(2.0)?
            .sub_scalar(0.5)?
            .clamp(0.0, 1.0)?
            .unsqueeze(0)?;
        let b = heatmap
            .mul_scalar(1.5)?
            .sub_scalar(1.0)?
            .clamp(0.0, 1.0)?
            .unsqueeze(0)?;

        let colored = Tensor::cat(&[&r, &g, &b], 0)?;
        Ok(colored)
    }

    /// Create placeholder heatmap (to be replaced with actual implementation)
    fn create_placeholder_heatmap(&self, height: usize, width: usize) -> Result<Tensor> {
        use torsh_tensor::creation;

        // Placeholder: an all-zero heatmap of the requested spatial size
        let heatmap: Tensor<f32> = creation::zeros(&[height, width])?;

        Ok(heatmap)
    }
}

/// Saliency Map Generator
///
/// Saliency maps show which input pixels have the greatest influence on the model's
/// prediction by computing the gradient of the output with respect to the input.
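///
/// For a class score `S_c` and input `x`, the map is
/// `M(x)_{ij} = max_k | ∂S_c / ∂x_{k,i,j} |`, taking the maximum over channels.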
pub struct SaliencyMap {
    device: Arc<dyn Device>,
}

impl SaliencyMap {
    /// Create a new saliency map generator
    pub fn new(device: Arc<dyn Device>) -> Self {
        Self { device }
    }

    /// Generate vanilla saliency map
    ///
    /// Computes the gradient of the class score with respect to the input image.
    ///
    /// # Arguments
    /// * `model` - The model to explain
    /// * `input` - Input image tensor [C, H, W] or [N, C, H, W]
    /// * `target_class` - Class index to generate saliency for
    ///
    /// # Returns
    /// Saliency map showing pixel importance
    pub fn generate(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
    ) -> Result<Tensor> {
        // Ensure batch dimension first, then enable gradient tracking on the
        // batched tensor so that it is the leaf we read the gradient back from
        let batched = if input.ndim() == 3 {
            input.unsqueeze(0)?
        } else {
            input.clone()
        };
        let batched = batched.requires_grad_(true);

        // Forward pass
        let output = model.forward(&batched)?;

        // Get score for target class
        let score = output.narrow(1, target_class as i64, 1)?;

        // Backward pass
        score.backward()?;

        // Get gradient with respect to input
        let grad = batched
            .grad()
            .ok_or_else(|| VisionError::Other(anyhow::anyhow!("No gradient computed")))?;

        // Take absolute value and max across channels
        let abs_grad = grad.abs()?;
        let saliency = abs_grad.max(Some(1), false)?;

        Ok(saliency)
    }

    /// Generate smooth saliency map
    ///
    /// Averages saliency maps computed from multiple noisy versions of the input
    /// to reduce noise and produce smoother visualizations.
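    ///
    /// This is the SmoothGrad technique (Smilkov et al., 2017):
    /// `M_smooth(x) = (1/n) Σ_{i=1}^{n} M(x + ε_i)` with `ε_i ~ N(0, σ²)`.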
    pub fn generate_smooth(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
        num_samples: usize,
        noise_stddev: f32,
    ) -> Result<Tensor> {
        use torsh_tensor::creation;

        let shape: Vec<usize> = input.shape().dims().iter().map(|&x| x as usize).collect();

        // The saliency map has its channel dimension reduced, so the
        // accumulator is initialized from the first sample rather than from
        // a zeros tensor of the input shape
        let mut accumulated: Option<Tensor> = None;

        for _ in 0..num_samples {
            // Add random noise to input
            let noise: Tensor<f32> = creation::randn(&shape)?;
            let noise = noise.mul_scalar(noise_stddev)?;
            let noisy_input = input.add(&noise)?;

            // Generate saliency for noisy input
            let saliency = self.generate(model, &noisy_input, target_class)?;

            // Accumulate
            accumulated = Some(match accumulated {
                Some(acc) => acc.add(&saliency)?,
                None => saliency,
            });
        }

        // Average
        let accumulated = accumulated.ok_or_else(|| {
            VisionError::InvalidArgument("num_samples must be greater than zero".to_string())
        })?;
        let smooth_saliency = accumulated.div_scalar(num_samples as f32)?;

        Ok(smooth_saliency)
    }
}

/// Integrated Gradients
///
/// Integrated Gradients is a method that attributes the prediction of a model to its
/// input features by integrating gradients along a path from a baseline to the input.
///
/// Reference: Sundararajan et al., "Axiomatic Attribution for Deep Networks", ICML 2017
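///
/// For input `x` and baseline `x'`, the attribution for feature `i` is
/// `IG_i(x) = (x_i - x'_i) · ∫_0^1 ∂F(x' + α(x - x')) / ∂x_i dα`,
/// approximated here by a Riemann sum over `num_steps` interpolation points.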
pub struct IntegratedGradients {
    baseline_type: BaselineType,
    num_steps: usize,
    device: Arc<dyn Device>,
}

/// Type of baseline to use for integrated gradients
#[derive(Debug, Clone, Copy)]
pub enum BaselineType {
    /// Black image (all zeros)
    Black,
    /// Random noise
    Random,
    /// Blurred version of input
    Blurred,
}

impl IntegratedGradients {
    /// Create a new integrated gradients explainer
    pub fn new(baseline_type: BaselineType, num_steps: usize, device: Arc<dyn Device>) -> Self {
        Self {
            baseline_type,
            num_steps,
            device,
        }
    }

    /// Generate integrated gradients attribution map
    ///
    /// # Arguments
    /// * `model` - The model to explain
    /// * `input` - Input image tensor [C, H, W] or [N, C, H, W]
    /// * `target_class` - Class index to generate attribution for
    ///
    /// # Returns
    /// Attribution map showing feature importance
    pub fn generate(
        &self,
        model: &dyn Module,
        input: &Tensor,
        target_class: usize,
    ) -> Result<Tensor> {
        use torsh_tensor::creation;

        // Create baseline
        let baseline = self.create_baseline(input)?;

        // Accumulate gradients along the straight-line path from baseline to
        // input; the accumulator starts at zero, not at the baseline values
        let shape: Vec<usize> = input.shape().dims().iter().map(|&x| x as usize).collect();
        let mut accumulated_gradients: Tensor<f32> = creation::zeros(&shape)?;

        for step in 0..self.num_steps {
            // alpha in (0, 1], so the final step evaluates at the input itself
            let alpha = (step as f32 + 1.0) / (self.num_steps as f32);

            // Interpolated input
            let interpolated = baseline
                .mul_scalar(1.0 - alpha)?
                .add(&input.mul_scalar(alpha)?)?;

            // Enable gradients
            let interp_with_grad = interpolated.clone().requires_grad_(true);

            // Forward pass
            let output = model.forward(&interp_with_grad)?;
            let score = output.narrow(1, target_class as i64, 1)?;

            // Backward pass
            score.backward()?;

            // Accumulate gradients
            let grad = interp_with_grad
                .grad()
                .ok_or_else(|| VisionError::Other(anyhow::anyhow!("No gradient computed")))?;
            accumulated_gradients = accumulated_gradients.add(&grad)?;
        }

        // Average gradients and multiply by (input - baseline)
        let avg_gradients = accumulated_gradients.div_scalar(self.num_steps as f32)?;
        let attribution = input.sub(&baseline)?.mul(&avg_gradients)?;

        Ok(attribution)
    }

    /// Create baseline based on baseline type
    fn create_baseline(&self, input: &Tensor) -> Result<Tensor> {
        use torsh_tensor::creation;

        let shape: Vec<usize> = input.shape().dims().iter().map(|&x| x as usize).collect();

        match self.baseline_type {
            BaselineType::Black => {
                let baseline: Tensor<f32> = creation::zeros(&shape)?;
                Ok(baseline)
            }
            BaselineType::Random => {
                let baseline: Tensor<f32> = creation::randn(&shape)?;
                Ok(baseline.mul_scalar(0.1)?)
            }
            BaselineType::Blurred => {
                // Placeholder: returns the input unchanged; in production,
                // apply an actual Gaussian blur here
                Ok(input.clone())
            }
        }
    }
}

/// Attention Visualization
///
/// For models with attention mechanisms (e.g., Vision Transformers),
/// visualizes the attention weights to understand which parts of the image
/// the model is focusing on.
pub struct AttentionVisualizer {
    device: Arc<dyn Device>,
}

impl AttentionVisualizer {
    /// Create a new attention visualizer
    pub fn new(device: Arc<dyn Device>) -> Self {
        Self { device }
    }

    /// Visualize attention weights from a transformer layer
    ///
    /// # Arguments
    /// * `attention_weights` - Attention weight tensor [N, num_heads, seq_len, seq_len]
    /// * `patch_size` - Size of image patches
    /// * `image_size` - Original image size (H, W)
    ///
    /// # Returns
    /// Attention map showing which regions are attended to
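    ///
    /// For example, a 224x224 image with 16x16 patches yields a 14x14 grid of
    /// 196 patch tokens, so `seq_len` is 197 once the CLS token is prepended.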
    pub fn visualize_attention(
        &self,
        attention_weights: &Tensor,
        patch_size: usize,
        image_size: (usize, usize),
    ) -> Result<Tensor> {
        // Average attention across heads: [N, seq_len, seq_len]
        let avg_attention = attention_weights.mean(Some(&[1]), false)?;

        // Extract attention from the CLS token (first row) to all tokens
        let cls_attention = avg_attention.narrow(1, 0, 1)?;

        // Drop the CLS-to-CLS entry so only the patch tokens remain
        // (assumes a CLS token is prepended to the patch sequence)
        let num_patches_h = image_size.0 / patch_size;
        let num_patches_w = image_size.1 / patch_size;
        let num_patches = num_patches_h * num_patches_w;
        let patch_attention = cls_attention.narrow(2, 1, num_patches as i64)?;

        // Reshape to spatial grid
        let reshaped =
            patch_attention.reshape(&[1i32, num_patches_h as i32, num_patches_w as i32])?;

        // Upsample to original image size
        // Placeholder: in production, use proper interpolation
        let upsampled = reshaped.clone();

        Ok(upsampled)
    }

    /// Visualize attention rollout
    ///
    /// Combines attention from multiple layers to understand
    /// the full attention flow through the network.
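    ///
    /// Rollout multiplies the per-layer attention matrices in order:
    /// `rollout = A_L · A_{L-1} · ... · A_1` (Abnar & Zuidema, 2020). A fuller
    /// variant would also average each `A_l` with the identity to account for
    /// residual connections; that refinement is not applied here.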
    pub fn attention_rollout(&self, attention_layers: Vec<Tensor>) -> Result<Tensor> {
        if attention_layers.is_empty() {
            return Err(VisionError::InvalidArgument(
                "No attention layers provided".to_string(),
            ));
        }

        // Start from the first layer's attention (equivalent to multiplying
        // an identity matrix by it)
        let mut rollout = attention_layers[0].clone();

        // Multiply attention matrices from consecutive layers
        for attention in attention_layers.iter().skip(1) {
            rollout = rollout.matmul(attention)?;
        }

        Ok(rollout)
    }
}

/// Feature Visualization
///
/// Generate synthetic images that maximally activate specific neurons or layers,
/// helping understand what features the network has learned.
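///
/// This is activation maximization: starting from noise, the image is updated
/// by gradient ascent on the class score,
/// `x_{t+1} = clamp(x_t + η · ∂S_c / ∂x_t)`.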
pub struct FeatureVisualizer {
    learning_rate: f32,
    num_iterations: usize,
    device: Arc<dyn Device>,
}

impl FeatureVisualizer {
    /// Create a new feature visualizer
    pub fn new(learning_rate: f32, num_iterations: usize, device: Arc<dyn Device>) -> Self {
        Self {
            learning_rate,
            num_iterations,
            device,
        }
    }

    /// Generate an image that maximally activates a specific class
    ///
    /// # Arguments
    /// * `model` - The model to visualize
    /// * `target_class` - Class to maximize activation for
    /// * `image_size` - Size of generated image (H, W)
    ///
    /// # Returns
    /// Synthesized image that maximally activates the target class
    pub fn visualize_class(
        &self,
        model: &dyn Module,
        target_class: usize,
        image_size: (usize, usize),
    ) -> Result<Tensor> {
        use torsh_tensor::creation;

        // Initialize random image
        let mut image: Tensor<f32> = creation::randn(&[1, 3, image_size.0, image_size.1])?;
        image = image.mul_scalar(0.1)?.add_scalar(0.5)?.requires_grad_(true);

        // Optimization loop
        for iteration in 0..self.num_iterations {
            // Forward pass
            let output = model.forward(&image)?;
            let class_score = output.narrow(1, target_class as i64, 1)?;

            // We want to maximize the class score
            let loss = class_score.neg()?;

            // Backward pass
            loss.backward()?;

            // Update image (gradient descent on the negated score)
            let grad = image
                .grad()
                .ok_or_else(|| VisionError::Other(anyhow::anyhow!("No gradient computed")))?;
            image = image.sub(&grad.mul_scalar(self.learning_rate)?)?;

            // Apply regularization (keep values in a reasonable range) and
            // re-enable gradient tracking for the next iteration
            image = image.clamp(-2.0, 2.0)?.requires_grad_(true);

            if iteration % 10 == 0 {
                println!("Iteration {}: loss = {:?}", iteration, loss.item());
            }
        }

        Ok(image)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::CpuDevice;
    use torsh_tensor::creation;

    #[test]
    fn test_gradcam_creation() {
        let device = Arc::new(CpuDevice::new());
        let gradcam = GradCAM::new("layer4".to_string(), device);
        assert_eq!(gradcam.target_layer_name, "layer4");
    }

    #[test]
    fn test_saliency_map_creation() {
        let device = Arc::new(CpuDevice::new());
        let _saliency = SaliencyMap::new(device);
    }

    #[test]
    fn test_integrated_gradients_creation() {
        let device = Arc::new(CpuDevice::new());
        let _ig = IntegratedGradients::new(BaselineType::Black, 50, device);
    }

    #[test]
    fn test_attention_visualizer_creation() {
        let device = Arc::new(CpuDevice::new());
        let _visualizer = AttentionVisualizer::new(device);
    }

    #[test]
    fn test_feature_visualizer_creation() {
        let device = Arc::new(CpuDevice::new());
        let _visualizer = FeatureVisualizer::new(0.1, 100, device);
    }

    #[test]
    fn test_baseline_types() {
        let device: Arc<dyn Device> = Arc::new(CpuDevice::new());
        let ig_black = IntegratedGradients::new(BaselineType::Black, 50, Arc::clone(&device));
        let ig_random = IntegratedGradients::new(BaselineType::Random, 50, Arc::clone(&device));
        let ig_blurred = IntegratedGradients::new(BaselineType::Blurred, 50, Arc::clone(&device));

        let input: Tensor<f32> = creation::ones(&[1, 3, 224, 224]).unwrap();

        assert!(ig_black.create_baseline(&input).is_ok());
        assert!(ig_random.create_baseline(&input).is_ok());
        assert!(ig_blurred.create_baseline(&input).is_ok());
    }
}
575}