sklears_inspection/
computer_vision.rs

1//! Computer Vision Interpretability Methods
2//!
3//! This module provides specialized interpretability methods for computer vision models,
4//! including image-specific LIME, Grad-CAM visualizations, saliency maps, object detection
5//! explanations, and segmentation explanations.
6
7use crate::{types::Float, SklResult, SklearsError};
8// ✅ SciRS2 Policy Compliant Import
9use scirs2_core::ndarray::{Array1, Array2, Array3};
10use scirs2_core::random::Rng;
11use std::collections::HashMap;
12
13/// Configuration for computer vision interpretability methods
14#[derive(Debug, Clone)]
15pub struct ComputerVisionConfig {
16    /// Number of superpixels for LIME segmentation
17    pub n_superpixels: usize,
18    /// Number of perturbations for LIME
19    pub n_perturbations: usize,
20    /// Smoothing parameter for Grad-CAM
21    pub gradcam_smoothing: Float,
22    /// Resolution for saliency maps
23    pub saliency_resolution: (usize, usize),
24    /// Threshold for object detection
25    pub detection_threshold: Float,
26    /// Minimum segment size for segmentation
27    pub min_segment_size: usize,
28    /// Noise level for perturbations
29    pub noise_level: Float,
30}
31
32impl Default for ComputerVisionConfig {
33    fn default() -> Self {
34        Self {
35            n_superpixels: 100,
36            n_perturbations: 1000,
37            gradcam_smoothing: 1.0,
38            saliency_resolution: (224, 224),
39            detection_threshold: 0.5,
40            min_segment_size: 50,
41            noise_level: 0.1,
42        }
43    }
44}
45
46/// Image with metadata for CV explanations
47#[derive(Debug, Clone)]
48pub struct Image {
49    /// Image data (height, width, channels)
50    pub data: Array3<Float>,
51    /// Image width
52    pub width: usize,
53    /// Image height
54    pub height: usize,
55    /// Number of channels
56    pub channels: usize,
57    /// Image format (e.g., "RGB", "BGR", "Grayscale")
58    pub format: String,
59}
60
61/// Superpixel segment
62#[derive(Debug, Clone)]
63pub struct Superpixel {
64    /// Segment identifier
65    pub id: usize,
66    /// Pixel coordinates in the segment
67    pub pixels: Vec<(usize, usize)>,
68    /// Mean color of the segment
69    pub mean_color: Array1<Float>,
70    /// Segment centroid
71    pub centroid: (Float, Float),
72    /// Segment area (number of pixels)
73    pub area: usize,
74}
75
76/// Image LIME explanation result
77#[derive(Debug, Clone)]
78pub struct ImageLimeResult {
79    /// Original image
80    pub image: Image,
81    /// Superpixel segmentation
82    pub superpixels: Vec<Superpixel>,
83    /// Importance scores for each superpixel
84    pub importance_scores: Array1<Float>,
85    /// Explanation mask (same size as image)
86    pub explanation_mask: Array2<Float>,
87    /// Positive and negative contributions
88    pub positive_mask: Array2<Float>,
89    /// negative_mask
90    pub negative_mask: Array2<Float>,
91}
92
93/// Grad-CAM explanation result
94#[derive(Debug, Clone)]
95pub struct GradCAMResult {
96    /// Original image
97    pub image: Image,
98    /// Heatmap showing important regions
99    pub heatmap: Array2<Float>,
100    /// Guided Grad-CAM result
101    pub guided_gradcam: Option<Array3<Float>>,
102    /// Class activation map
103    pub class_activation: Array2<Float>,
104    /// Target class index
105    pub target_class: usize,
106    /// Activation statistics
107    pub activation_stats: GradCAMStats,
108}
109
110/// Grad-CAM statistics
111#[derive(Debug, Clone)]
112pub struct GradCAMStats {
113    /// Maximum activation value
114    pub max_activation: Float,
115    /// Mean activation value
116    pub mean_activation: Float,
117    /// Standard deviation of activations
118    pub std_activation: Float,
119    /// Percentage of image with high activation
120    pub high_activation_percentage: Float,
121}
122
123/// Saliency map result
124#[derive(Debug, Clone)]
125pub struct SaliencyMapResult {
126    /// Original image
127    pub image: Image,
128    /// Saliency map
129    pub saliency_map: Array2<Float>,
130    /// Integrated gradients
131    pub integrated_gradients: Option<Array3<Float>>,
132    /// Smooth gradients
133    pub smooth_gradients: Option<Array3<Float>>,
134    /// Method used for saliency computation
135    pub method: SaliencyMethod,
136}
137
/// Algorithm used to compute a saliency map.
///
/// This is a fieldless enum, so it also derives `Copy` and `Hash`. It can be passed by
/// value freely and used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SaliencyMethod {
    /// Plain input gradients (finite differences of the model output).
    Vanilla,
    /// Integrated gradients along a straight path from an all-zero baseline.
    IntegratedGradients,
    /// SmoothGrad: gradients averaged over noisy copies of the input.
    SmoothGrad,
    /// Guided backpropagation (currently approximated by vanilla gradients).
    GuidedBackprop,
}
150
151/// Object detection explanation
152#[derive(Debug, Clone)]
153pub struct ObjectDetectionExplanation {
154    /// Original image
155    pub image: Image,
156    /// Detected objects with explanations
157    pub detections: Vec<DetectedObject>,
158    /// Global attention map
159    pub attention_map: Array2<Float>,
160    /// Feature importance for detection
161    pub feature_importance: Array2<Float>,
162}
163
164/// Detected object with explanation
165#[derive(Debug, Clone)]
166pub struct DetectedObject {
167    /// Bounding box (x, y, width, height)
168    pub bbox: (Float, Float, Float, Float),
169    /// Object class
170    pub class: String,
171    /// Confidence score
172    pub confidence: Float,
173    /// Local explanation for this detection
174    pub local_explanation: Array2<Float>,
175    /// Important features for this object
176    pub key_features: Vec<KeyFeature>,
177}
178
179/// Key feature for object detection
180#[derive(Debug, Clone)]
181pub struct KeyFeature {
182    /// Feature name/description
183    pub name: String,
184    /// Feature location (x, y)
185    pub location: (Float, Float),
186    /// Feature importance score
187    pub importance: Float,
188    /// Feature size
189    pub size: (Float, Float),
190}
191
192/// Segmentation explanation
193#[derive(Debug, Clone)]
194pub struct SegmentationExplanation {
195    /// Original image
196    pub image: Image,
197    /// Segmentation mask
198    pub segmentation_mask: Array2<usize>,
199    /// Per-pixel explanations
200    pub pixel_explanations: Array2<Float>,
201    /// Segment-level explanations
202    pub segments: Vec<SegmentExplanation>,
203    /// Boundary explanations
204    pub boundary_importance: Array2<Float>,
205}
206
207/// Explanation for a single segment
208#[derive(Debug, Clone)]
209pub struct SegmentExplanation {
210    /// Segment identifier
211    pub segment_id: usize,
212    /// Segment class/label
213    pub class: String,
214    /// Segment confidence
215    pub confidence: Float,
216    /// Pixels in this segment
217    pub pixels: Vec<(usize, usize)>,
218    /// Segment-level importance
219    pub importance: Float,
220    /// Key visual features
221    pub key_features: Vec<String>,
222}
223
224/// Image-specific LIME explanation
225pub fn explain_image_with_lime<F>(
226    image: &Image,
227    model_fn: F,
228    config: &ComputerVisionConfig,
229) -> SklResult<ImageLimeResult>
230where
231    F: Fn(&Array3<Float>) -> SklResult<Array1<Float>>,
232{
233    // Generate superpixel segmentation
234    let superpixels = generate_superpixels(image, config)?;
235
236    // Generate perturbations by masking different superpixels
237    let mut perturbations = Vec::new();
238    let mut labels = Vec::new();
239
240    // Original prediction
241    let original_pred = model_fn(&image.data)?;
242
243    // Generate perturbations
244    for _ in 0..config.n_perturbations {
245        let (perturbed_image, active_superpixels) =
246            generate_image_perturbation(image, &superpixels, 0.5)?;
247        let pred = model_fn(&perturbed_image)?;
248
249        perturbations.push(active_superpixels);
250        labels.push(pred);
251    }
252
253    // Compute importance scores
254    let importance_scores = compute_lime_weights_image(&perturbations, &labels, &original_pred)?;
255
256    // Generate explanation masks
257    let (explanation_mask, positive_mask, negative_mask) =
258        generate_explanation_masks(image, &superpixels, &importance_scores)?;
259
260    Ok(ImageLimeResult {
261        image: image.clone(),
262        superpixels,
263        importance_scores,
264        explanation_mask,
265        positive_mask,
266        negative_mask,
267    })
268}
269
270/// Generate Grad-CAM visualization
271pub fn generate_gradcam<F>(
272    image: &Image,
273    model_fn: F,
274    target_class: usize,
275    config: &ComputerVisionConfig,
276) -> SklResult<GradCAMResult>
277where
278    F: Fn(&Array3<Float>) -> SklResult<(Array1<Float>, Array3<Float>)>, // Returns (predictions, feature_maps)
279{
280    // Get model predictions and feature maps
281    let (predictions, feature_maps) = model_fn(&image.data)?;
282
283    if target_class >= predictions.len() {
284        return Err(SklearsError::InvalidInput(
285            "Target class index out of range".to_string(),
286        ));
287    }
288
289    // Compute gradients (simplified - in practice would use automatic differentiation)
290    let gradients = compute_gradients(&predictions, &feature_maps, target_class)?;
291
292    // Generate class activation map
293    let class_activation = compute_class_activation_map(&feature_maps, &gradients)?;
294
295    // Apply smoothing
296    let heatmap = apply_smoothing(&class_activation, config.gradcam_smoothing)?;
297
298    // Compute statistics
299    let activation_stats = compute_gradcam_stats(&heatmap)?;
300
301    Ok(GradCAMResult {
302        image: image.clone(),
303        heatmap,
304        guided_gradcam: None,
305        class_activation,
306        target_class,
307        activation_stats,
308    })
309}
310
311/// Generate saliency map
312pub fn generate_saliency_map<F>(
313    image: &Image,
314    model_fn: F,
315    method: SaliencyMethod,
316    config: &ComputerVisionConfig,
317) -> SklResult<SaliencyMapResult>
318where
319    F: Fn(&Array3<Float>) -> SklResult<Array1<Float>>,
320{
321    let saliency_map = match method {
322        SaliencyMethod::Vanilla => compute_vanilla_gradients(image, &model_fn)?,
323        SaliencyMethod::IntegratedGradients => {
324            compute_integrated_gradients(image, &model_fn, config)?
325        }
326        SaliencyMethod::SmoothGrad => compute_smooth_gradients(image, &model_fn, config)?,
327        SaliencyMethod::GuidedBackprop => compute_guided_backprop(image, &model_fn)?,
328    };
329
330    Ok(SaliencyMapResult {
331        image: image.clone(),
332        saliency_map,
333        integrated_gradients: None,
334        smooth_gradients: None,
335        method,
336    })
337}
338
339/// Explain object detection
340pub fn explain_object_detection<F>(
341    image: &Image,
342    model_fn: F,
343    config: &ComputerVisionConfig,
344) -> SklResult<ObjectDetectionExplanation>
345where
346    F: Fn(&Array3<Float>) -> SklResult<Vec<(Float, Float, Float, Float, String, Float)>>, // Returns (x, y, w, h, class, confidence)
347{
348    // Get detections
349    let detections_raw = model_fn(&image.data)?;
350
351    // Filter detections by confidence threshold
352    let filtered_detections: Vec<_> = detections_raw
353        .into_iter()
354        .filter(|(_, _, _, _, _, conf)| *conf >= config.detection_threshold)
355        .collect();
356
357    // Generate explanations for each detection
358    let mut detections = Vec::new();
359    for (x, y, w, h, class, confidence) in filtered_detections {
360        let bbox = (x, y, w, h);
361        let local_explanation = generate_detection_explanation(image, &bbox, config)?;
362        let key_features = extract_key_features(image, &bbox, &local_explanation)?;
363
364        detections.push(DetectedObject {
365            bbox,
366            class,
367            confidence,
368            local_explanation,
369            key_features,
370        });
371    }
372
373    // Generate global attention map
374    let attention_map = generate_global_attention_map(image, &detections)?;
375
376    // Compute feature importance
377    let feature_importance = compute_detection_feature_importance(image, &detections)?;
378
379    Ok(ObjectDetectionExplanation {
380        image: image.clone(),
381        detections,
382        attention_map,
383        feature_importance,
384    })
385}
386
387/// Explain image segmentation
388pub fn explain_segmentation<F>(
389    image: &Image,
390    model_fn: F,
391    config: &ComputerVisionConfig,
392) -> SklResult<SegmentationExplanation>
393where
394    F: Fn(&Array3<Float>) -> SklResult<(Array2<usize>, Array1<Float>)>, // Returns (segmentation, confidence)
395{
396    // Get segmentation
397    let (segmentation_mask, confidence) = model_fn(&image.data)?;
398
399    // Generate per-pixel explanations
400    let pixel_explanations = generate_pixel_explanations(image, &segmentation_mask, config)?;
401
402    // Extract segments
403    let segments = extract_segments(&segmentation_mask, &pixel_explanations, &confidence, config)?;
404
405    // Compute boundary importance
406    let boundary_importance = compute_boundary_importance(&segmentation_mask, &pixel_explanations)?;
407
408    Ok(SegmentationExplanation {
409        image: image.clone(),
410        segmentation_mask,
411        pixel_explanations,
412        segments,
413        boundary_importance,
414    })
415}
416
417// Helper functions
418
419/// Generate superpixels using simple grid-based segmentation
420fn generate_superpixels(
421    image: &Image,
422    config: &ComputerVisionConfig,
423) -> SklResult<Vec<Superpixel>> {
424    let mut superpixels = Vec::new();
425
426    // Simple grid-based superpixels (in practice, would use SLIC or similar)
427    let grid_size = (config.n_superpixels as Float).sqrt() as usize;
428    let step_x = image.width / grid_size;
429    let step_y = image.height / grid_size;
430
431    let mut id = 0;
432    for row in 0..grid_size {
433        for col in 0..grid_size {
434            let start_x = col * step_x;
435            let end_x = ((col + 1) * step_x).min(image.width);
436            let start_y = row * step_y;
437            let end_y = ((row + 1) * step_y).min(image.height);
438
439            let mut pixels = Vec::new();
440            let mut color_sum = Array1::zeros(image.channels);
441
442            for y in start_y..end_y {
443                for x in start_x..end_x {
444                    pixels.push((x, y));
445                    if y < image.data.shape()[0] && x < image.data.shape()[1] {
446                        for c in 0..image.channels {
447                            if c < image.data.shape()[2] {
448                                color_sum[c] += image.data[[y, x, c]];
449                            }
450                        }
451                    }
452                }
453            }
454
455            let mean_color = if !pixels.is_empty() {
456                color_sum / pixels.len() as Float
457            } else {
458                Array1::zeros(image.channels)
459            };
460
461            let centroid = (
462                (start_x + end_x) as Float / 2.0,
463                (start_y + end_y) as Float / 2.0,
464            );
465
466            superpixels.push(Superpixel {
467                id,
468                pixels,
469                mean_color,
470                centroid,
471                area: (end_x - start_x) * (end_y - start_y),
472            });
473
474            id += 1;
475        }
476    }
477
478    Ok(superpixels)
479}
480
481/// Generate image perturbation by masking superpixels
482fn generate_image_perturbation(
483    image: &Image,
484    superpixels: &[Superpixel],
485    mask_probability: Float,
486) -> SklResult<(Array3<Float>, Vec<bool>)> {
487    let mut perturbed = image.data.clone();
488    let mut active_superpixels = vec![false; superpixels.len()];
489
490    for (i, superpixel) in superpixels.iter().enumerate() {
491        let is_active = scirs2_core::random::thread_rng().random::<Float>() > mask_probability;
492        active_superpixels[i] = is_active;
493
494        if !is_active {
495            // Mask this superpixel (set to mean color or zero)
496            for &(x, y) in &superpixel.pixels {
497                if y < perturbed.shape()[0] && x < perturbed.shape()[1] {
498                    for c in 0..perturbed.shape()[2] {
499                        perturbed[[y, x, c]] = 0.0; // or superpixel.mean_color[c]
500                    }
501                }
502            }
503        }
504    }
505
506    Ok((perturbed, active_superpixels))
507}
508
509/// Compute LIME weights for image explanation
510fn compute_lime_weights_image(
511    perturbations: &[Vec<bool>],
512    predictions: &[Array1<Float>],
513    original_pred: &Array1<Float>,
514) -> SklResult<Array1<Float>> {
515    if perturbations.is_empty() || predictions.is_empty() {
516        return Err(SklearsError::InvalidInput(
517            "No perturbations provided".to_string(),
518        ));
519    }
520
521    let n_superpixels = perturbations[0].len();
522    let mut weights = Array1::zeros(n_superpixels);
523
524    // Simple correlation-based weights
525    for j in 0..n_superpixels {
526        let mut correlation = 0.0;
527        let mut count = 0;
528
529        for (i, pred) in predictions.iter().enumerate() {
530            if i < perturbations.len() {
531                let feature_active = perturbations[i][j] as usize as Float;
532                let target_change = (pred[0] - original_pred[0]).abs();
533                correlation += feature_active * target_change;
534                count += 1;
535            }
536        }
537
538        if count > 0 {
539            weights[j] = correlation / count as Float;
540        }
541    }
542
543    Ok(weights)
544}
545
546/// Generate explanation masks from superpixels and importance scores
547fn generate_explanation_masks(
548    image: &Image,
549    superpixels: &[Superpixel],
550    importance_scores: &Array1<Float>,
551) -> SklResult<(Array2<Float>, Array2<Float>, Array2<Float>)> {
552    let mut explanation_mask = Array2::zeros((image.height, image.width));
553    let mut positive_mask = Array2::zeros((image.height, image.width));
554    let mut negative_mask = Array2::zeros((image.height, image.width));
555
556    for (i, superpixel) in superpixels.iter().enumerate() {
557        let importance = importance_scores[i];
558
559        for &(x, y) in &superpixel.pixels {
560            if y < explanation_mask.shape()[0] && x < explanation_mask.shape()[1] {
561                explanation_mask[[y, x]] = importance;
562
563                if importance > 0.0 {
564                    positive_mask[[y, x]] = importance;
565                } else {
566                    negative_mask[[y, x]] = importance.abs();
567                }
568            }
569        }
570    }
571
572    Ok((explanation_mask, positive_mask, negative_mask))
573}
574
575/// Compute gradients (simplified implementation)
576fn compute_gradients(
577    predictions: &Array1<Float>,
578    feature_maps: &Array3<Float>,
579    target_class: usize,
580) -> SklResult<Array3<Float>> {
581    // Simplified gradient computation
582    // In practice, this would use automatic differentiation
583    let mut gradients = Array3::zeros(feature_maps.raw_dim());
584
585    if target_class < predictions.len() {
586        let target_score = predictions[target_class];
587
588        // Simple approximation: gradients proportional to feature maps and target score
589        for ((i, j, k), &feature_val) in feature_maps.indexed_iter() {
590            gradients[[i, j, k]] = feature_val * target_score;
591        }
592    }
593
594    Ok(gradients)
595}
596
597/// Compute class activation map
598fn compute_class_activation_map(
599    feature_maps: &Array3<Float>,
600    gradients: &Array3<Float>,
601) -> SklResult<Array2<Float>> {
602    if feature_maps.shape() != gradients.shape() {
603        return Err(SklearsError::InvalidInput(
604            "Feature maps and gradients shape mismatch".to_string(),
605        ));
606    }
607
608    let (height, width, _) = feature_maps.dim();
609    let mut activation_map = Array2::zeros((height, width));
610
611    // Global average pooling of gradients for weights
612    let mut weights = Array1::zeros(feature_maps.shape()[2]);
613    for k in 0..feature_maps.shape()[2] {
614        let mut sum = 0.0;
615        let mut count = 0;
616
617        for i in 0..height {
618            for j in 0..width {
619                sum += gradients[[i, j, k]];
620                count += 1;
621            }
622        }
623
624        weights[k] = if count > 0 { sum / count as Float } else { 0.0 };
625    }
626
627    // Weighted combination of feature maps
628    for i in 0..height {
629        for j in 0..width {
630            let mut activation: Float = 0.0;
631            for k in 0..feature_maps.shape()[2] {
632                activation += weights[k] * feature_maps[[i, j, k]];
633            }
634            activation_map[[i, j]] = activation.max(0.0); // ReLU
635        }
636    }
637
638    Ok(activation_map)
639}
640
641/// Apply smoothing to heatmap
642fn apply_smoothing(heatmap: &Array2<Float>, smoothing: Float) -> SklResult<Array2<Float>> {
643    // Simple Gaussian-like smoothing
644    let mut smoothed = heatmap.clone();
645
646    if smoothing > 0.0 {
647        let (height, width) = heatmap.dim();
648
649        for i in 1..(height - 1) {
650            for j in 1..(width - 1) {
651                let mut sum = 0.0;
652                let mut count = 0;
653
654                // 3x3 kernel
655                for di in -1i32..=1 {
656                    for dj in -1i32..=1 {
657                        let ni = (i as i32 + di) as usize;
658                        let nj = (j as i32 + dj) as usize;
659
660                        if ni < height && nj < width {
661                            sum += heatmap[[ni, nj]];
662                            count += 1;
663                        }
664                    }
665                }
666
667                if count > 0 {
668                    smoothed[[i, j]] = sum / count as Float;
669                }
670            }
671        }
672    }
673
674    Ok(smoothed)
675}
676
677/// Compute Grad-CAM statistics
678fn compute_gradcam_stats(heatmap: &Array2<Float>) -> SklResult<GradCAMStats> {
679    let values: Vec<Float> = heatmap.iter().cloned().collect();
680
681    if values.is_empty() {
682        return Ok(GradCAMStats {
683            max_activation: 0.0,
684            mean_activation: 0.0,
685            std_activation: 0.0,
686            high_activation_percentage: 0.0,
687        });
688    }
689
690    let max_activation = values.iter().cloned().fold(Float::NEG_INFINITY, Float::max);
691    let mean_activation = values.iter().sum::<Float>() / values.len() as Float;
692
693    let variance = values
694        .iter()
695        .map(|&x| (x - mean_activation).powi(2))
696        .sum::<Float>()
697        / values.len() as Float;
698    let std_activation = variance.sqrt();
699
700    let threshold = mean_activation + std_activation;
701    let high_count = values.iter().filter(|&&x| x > threshold).count();
702    let high_activation_percentage = high_count as Float / values.len() as Float;
703
704    Ok(GradCAMStats {
705        max_activation,
706        mean_activation,
707        std_activation,
708        high_activation_percentage,
709    })
710}
711
712/// Compute vanilla gradients
713fn compute_vanilla_gradients<F>(image: &Image, model_fn: F) -> SklResult<Array2<Float>>
714where
715    F: Fn(&Array3<Float>) -> SklResult<Array1<Float>>,
716{
717    // Simplified gradient computation
718    let (height, width) = (image.height, image.width);
719    let mut gradients = Array2::zeros((height, width));
720
721    let original_pred = model_fn(&image.data)?;
722    let epsilon = 1e-4;
723
724    // Finite difference approximation
725    for i in 0..height {
726        for j in 0..width {
727            let mut perturbed = image.data.clone();
728
729            // Perturb pixel
730            for c in 0..image.channels {
731                if c < perturbed.shape()[2] {
732                    perturbed[[i, j, c]] += epsilon;
733                }
734            }
735
736            let perturbed_pred = model_fn(&perturbed)?;
737            let gradient = (perturbed_pred[0] - original_pred[0]) / epsilon;
738            gradients[[i, j]] = gradient;
739        }
740    }
741
742    Ok(gradients)
743}
744
745/// Compute integrated gradients
746fn compute_integrated_gradients<F>(
747    image: &Image,
748    model_fn: F,
749    config: &ComputerVisionConfig,
750) -> SklResult<Array2<Float>>
751where
752    F: Fn(&Array3<Float>) -> SklResult<Array1<Float>>,
753{
754    let n_steps = 50;
755    let baseline = Array3::<Float>::zeros(image.data.raw_dim());
756
757    let mut integrated_gradients = Array2::zeros((image.height, image.width));
758
759    for step in 0..n_steps {
760        let alpha = step as Float / n_steps as Float;
761
762        // Interpolate between baseline and input
763        let interpolated = &baseline + alpha * (&image.data - &baseline);
764
765        // Compute gradients at this point
766        let gradients = compute_vanilla_gradients(
767            &Image {
768                data: interpolated,
769                width: image.width,
770                height: image.height,
771                channels: image.channels,
772                format: image.format.clone(),
773            },
774            &model_fn,
775        )?;
776
777        integrated_gradients = integrated_gradients + gradients;
778    }
779
780    // Average and multiply by (input - baseline)
781    integrated_gradients /= n_steps as Float;
782
783    Ok(integrated_gradients)
784}
785
786/// Compute smooth gradients
787fn compute_smooth_gradients<F>(
788    image: &Image,
789    model_fn: F,
790    config: &ComputerVisionConfig,
791) -> SklResult<Array2<Float>>
792where
793    F: Fn(&Array3<Float>) -> SklResult<Array1<Float>>,
794{
795    let n_samples = 50;
796    let mut smooth_gradients = Array2::zeros((image.height, image.width));
797
798    for _ in 0..n_samples {
799        // Add noise to image
800        let mut noisy_image = image.data.clone();
801        for elem in noisy_image.iter_mut() {
802            *elem +=
803                (scirs2_core::random::thread_rng().random::<Float>() - 0.5) * config.noise_level;
804        }
805
806        // Compute gradients
807        let gradients = compute_vanilla_gradients(
808            &Image {
809                data: noisy_image,
810                width: image.width,
811                height: image.height,
812                channels: image.channels,
813                format: image.format.clone(),
814            },
815            &model_fn,
816        )?;
817
818        smooth_gradients = smooth_gradients + gradients;
819    }
820
821    smooth_gradients /= n_samples as Float;
822    Ok(smooth_gradients)
823}
824
825/// Compute guided backpropagation
826fn compute_guided_backprop<F>(image: &Image, model_fn: F) -> SklResult<Array2<Float>>
827where
828    F: Fn(&Array3<Float>) -> SklResult<Array1<Float>>,
829{
830    // Simplified - in practice would need access to intermediate activations
831    compute_vanilla_gradients(image, model_fn)
832}
833
834/// Generate detection explanation for a bounding box
835fn generate_detection_explanation(
836    image: &Image,
837    bbox: &(Float, Float, Float, Float),
838    config: &ComputerVisionConfig,
839) -> SklResult<Array2<Float>> {
840    let (x, y, w, h) = *bbox;
841    let mut explanation = Array2::zeros((image.height, image.width));
842
843    // Simple box-based explanation
844    let x_start = (x as usize).min(image.width);
845    let y_start = (y as usize).min(image.height);
846    let x_end = ((x + w) as usize).min(image.width);
847    let y_end = ((y + h) as usize).min(image.height);
848
849    for i in y_start..y_end {
850        for j in x_start..x_end {
851            explanation[[i, j]] = 1.0;
852        }
853    }
854
855    Ok(explanation)
856}
857
858/// Extract key features from bounding box region
859fn extract_key_features(
860    image: &Image,
861    bbox: &(Float, Float, Float, Float),
862    explanation: &Array2<Float>,
863) -> SklResult<Vec<KeyFeature>> {
864    let mut features = Vec::new();
865
866    let (x, y, w, h) = *bbox;
867
868    // Simple feature: center of bounding box
869    features.push(KeyFeature {
870        name: "Center".to_string(),
871        location: (x + w / 2.0, y + h / 2.0),
872        importance: 1.0,
873        size: (w, h),
874    });
875
876    // Add corners as features
877    features.push(KeyFeature {
878        name: "Top-left corner".to_string(),
879        location: (x, y),
880        importance: 0.8,
881        size: (w * 0.1, h * 0.1),
882    });
883
884    features.push(KeyFeature {
885        name: "Bottom-right corner".to_string(),
886        location: (x + w, y + h),
887        importance: 0.8,
888        size: (w * 0.1, h * 0.1),
889    });
890
891    Ok(features)
892}
893
894/// Generate global attention map from all detections
895fn generate_global_attention_map(
896    image: &Image,
897    detections: &[DetectedObject],
898) -> SklResult<Array2<Float>> {
899    let mut attention_map = Array2::zeros((image.height, image.width));
900
901    for detection in detections {
902        // Add local explanation weighted by confidence
903        for i in 0..image.height {
904            for j in 0..image.width {
905                if i < detection.local_explanation.shape()[0]
906                    && j < detection.local_explanation.shape()[1]
907                {
908                    attention_map[[i, j]] +=
909                        detection.local_explanation[[i, j]] * detection.confidence;
910                }
911            }
912        }
913    }
914
915    Ok(attention_map)
916}
917
918/// Compute feature importance for object detection
919fn compute_detection_feature_importance(
920    image: &Image,
921    detections: &[DetectedObject],
922) -> SklResult<Array2<Float>> {
923    let mut feature_importance = Array2::zeros((image.height, image.width));
924
925    // Compute importance based on overlap with detection regions
926    for detection in detections {
927        let (x, y, w, h) = detection.bbox;
928        let x_start = (x as usize).min(image.width);
929        let y_start = (y as usize).min(image.height);
930        let x_end = ((x + w) as usize).min(image.width);
931        let y_end = ((y + h) as usize).min(image.height);
932
933        for i in y_start..y_end {
934            for j in x_start..x_end {
935                feature_importance[[i, j]] += detection.confidence;
936            }
937        }
938    }
939
940    Ok(feature_importance)
941}
942
943/// Generate per-pixel explanations for segmentation
944fn generate_pixel_explanations(
945    image: &Image,
946    segmentation: &Array2<usize>,
947    config: &ComputerVisionConfig,
948) -> SklResult<Array2<Float>> {
949    let mut explanations = Array2::zeros((image.height, image.width));
950
951    // Simple explanation based on segment membership
952    for i in 0..image.height {
953        for j in 0..image.width {
954            if i < segmentation.shape()[0] && j < segmentation.shape()[1] {
955                let segment_id = segmentation[[i, j]];
956                explanations[[i, j]] = (segment_id as Float + 1.0) / 256.0; // Normalize
957            }
958        }
959    }
960
961    Ok(explanations)
962}
963
964/// Extract segments from segmentation mask
965fn extract_segments(
966    segmentation: &Array2<usize>,
967    pixel_explanations: &Array2<Float>,
968    confidence: &Array1<Float>,
969    config: &ComputerVisionConfig,
970) -> SklResult<Vec<SegmentExplanation>> {
971    let mut segments = Vec::new();
972    let mut segment_pixels: HashMap<usize, Vec<(usize, usize)>> = HashMap::new();
973
974    // Group pixels by segment
975    for ((i, j), &segment_id) in segmentation.indexed_iter() {
976        segment_pixels.entry(segment_id).or_default().push((i, j));
977    }
978
979    // Create segment explanations
980    for (segment_id, pixels) in segment_pixels {
981        if pixels.len() >= config.min_segment_size {
982            let importance = pixels
983                .iter()
984                .map(|&(i, j)| {
985                    if i < pixel_explanations.shape()[0] && j < pixel_explanations.shape()[1] {
986                        pixel_explanations[[i, j]]
987                    } else {
988                        0.0
989                    }
990                })
991                .sum::<Float>()
992                / pixels.len() as Float;
993
994            let segment_confidence = if segment_id < confidence.len() {
995                confidence[segment_id]
996            } else {
997                0.5
998            };
999
1000            segments.push(SegmentExplanation {
1001                segment_id,
1002                class: format!("class_{}", segment_id),
1003                confidence: segment_confidence,
1004                pixels,
1005                importance,
1006                key_features: vec!["color".to_string(), "texture".to_string()],
1007            });
1008        }
1009    }
1010
1011    Ok(segments)
1012}
1013
1014/// Compute boundary importance for segmentation
1015fn compute_boundary_importance(
1016    segmentation: &Array2<usize>,
1017    pixel_explanations: &Array2<Float>,
1018) -> SklResult<Array2<Float>> {
1019    let (height, width) = segmentation.dim();
1020    let mut boundary_importance = Array2::zeros((height, width));
1021
1022    // Detect boundaries using simple edge detection
1023    for i in 1..(height - 1) {
1024        for j in 1..(width - 1) {
1025            let center_segment = segmentation[[i, j]];
1026            let mut is_boundary = false;
1027
1028            // Check 8-connected neighbors
1029            for di in -1i32..=1 {
1030                for dj in -1i32..=1 {
1031                    if di == 0 && dj == 0 {
1032                        continue;
1033                    }
1034
1035                    let ni = (i as i32 + di) as usize;
1036                    let nj = (j as i32 + dj) as usize;
1037
1038                    if ni < height && nj < width && segmentation[[ni, nj]] != center_segment {
1039                        is_boundary = true;
1040                        break;
1041                    }
1042                }
1043                if is_boundary {
1044                    break;
1045                }
1046            }
1047
1048            if is_boundary {
1049                boundary_importance[[i, j]] = pixel_explanations[[i, j]];
1050            }
1051        }
1052    }
1053
1054    Ok(boundary_importance)
1055}
1056
#[cfg(test)]
mod tests {
    // Unit tests for the computer-vision interpretability helpers in this module.
    use super::*;
    // ✅ SciRS2 Policy Compliant Import
    use scirs2_core::ndarray::array;

    // Builds a 4x4, 3-channel, all-zero "RGB" image shared by most tests below.
    fn create_test_image() -> Image {
        Image {
            data: Array3::zeros((4, 4, 3)),
            width: 4,
            height: 4,
            channels: 3,
            format: "RGB".to_string(),
        }
    }

    // Pins every field of `ComputerVisionConfig::default()`.
    #[test]
    fn test_computer_vision_config_default() {
        let config = ComputerVisionConfig::default();

        assert_eq!(config.n_superpixels, 100);
        assert_eq!(config.n_perturbations, 1000);
        assert_eq!(config.gradcam_smoothing, 1.0);
        assert_eq!(config.saliency_resolution, (224, 224));
        assert_eq!(config.detection_threshold, 0.5);
        assert_eq!(config.min_segment_size, 50);
        assert_eq!(config.noise_level, 0.1);
    }

    // Requesting 4 superpixels on the 4x4 test image must yield exactly 4 non-empty
    // superpixels, each with a 3-component mean colour (one per channel).
    #[test]
    fn test_superpixel_generation() {
        let image = create_test_image();
        let config = ComputerVisionConfig {
            n_superpixels: 4,
            ..Default::default()
        };

        let superpixels = generate_superpixels(&image, &config).unwrap();

        assert_eq!(superpixels.len(), 4);
        for superpixel in &superpixels {
            assert!(!superpixel.pixels.is_empty());
            assert_eq!(superpixel.mean_color.len(), 3);
        }
    }

    // A perturbed image keeps the original shape, and the returned activity vector
    // has one entry per superpixel.
    #[test]
    fn test_image_perturbation() {
        let image = create_test_image();
        let config = ComputerVisionConfig {
            n_superpixels: 4,
            ..Default::default()
        };
        let superpixels = generate_superpixels(&image, &config).unwrap();

        let (perturbed, active) = generate_image_perturbation(&image, &superpixels, 0.5).unwrap();

        assert_eq!(perturbed.shape(), image.data.shape());
        assert_eq!(active.len(), superpixels.len());
    }

    // Sanity checks on Grad-CAM summary statistics for a small 2x2 heatmap;
    // `high_activation_percentage` is expected to be a fraction in [0, 1].
    #[test]
    fn test_gradcam_stats() {
        let heatmap = array![[0.1, 0.8], [0.3, 0.9]];
        let stats = compute_gradcam_stats(&heatmap).unwrap();

        assert_eq!(stats.max_activation, 0.9);
        assert!(stats.mean_activation > 0.0);
        assert!(stats.std_activation > 0.0);
        assert!(stats.high_activation_percentage >= 0.0);
        assert!(stats.high_activation_percentage <= 1.0);
    }

    // Smoothing must preserve the heatmap shape.
    #[test]
    fn test_smoothing() {
        let heatmap = array![[1.0, 0.0], [0.0, 1.0]];
        let smoothed = apply_smoothing(&heatmap, 1.0).unwrap();

        assert_eq!(smoothed.shape(), heatmap.shape());
        // Smoothed values should be different from original (except edges)
        // NOTE(review): the expectation above is not asserted; only the shape is checked.
    }

    // With all-ones feature maps and gradients, the class activation map is 2x2
    // (the first two dimensions of the inputs) and contains no negative values.
    #[test]
    fn test_class_activation_map() {
        let feature_maps = Array3::ones((2, 2, 3));
        let gradients = Array3::ones((2, 2, 3));

        let activation_map = compute_class_activation_map(&feature_maps, &gradients).unwrap();

        assert_eq!(activation_map.shape(), &[2, 2]);
        // All values should be non-negative; presumably compute_class_activation_map
        // applies a ReLU (its body is not visible here — confirm).
        for &val in activation_map.iter() {
            assert!(val >= 0.0);
        }
    }

    // Generates explanation masks from mixed-sign importance scores (one per
    // superpixel) and checks that no pixel is marked in both the positive and the
    // negative mask.
    #[test]
    fn test_explanation_mask_generation() {
        let image = create_test_image();
        let config = ComputerVisionConfig {
            n_superpixels: 4,
            ..Default::default()
        };
        let superpixels = generate_superpixels(&image, &config).unwrap();
        let importance_scores = Array1::from_vec(vec![0.5, -0.3, 0.8, -0.1]);

        let (explanation_mask, positive_mask, negative_mask) =
            generate_explanation_masks(&image, &superpixels, &importance_scores).unwrap();

        assert_eq!(explanation_mask.shape(), &[4, 4]);
        assert_eq!(positive_mask.shape(), &[4, 4]);
        assert_eq!(negative_mask.shape(), &[4, 4]);

        // Check that positive and negative masks are properly separated
        for i in 0..4 {
            for j in 0..4 {
                if positive_mask[[i, j]] > 0.0 {
                    assert_eq!(negative_mask[[i, j]], 0.0);
                }
                if negative_mask[[i, j]] > 0.0 {
                    assert_eq!(positive_mask[[i, j]], 0.0);
                }
            }
        }
    }

    // Key-feature extraction for a box inside the image must return at least a
    // feature named "Center". The bbox tuple is presumably (x, y, w, h), matching
    // how `DetectedObject::bbox` is destructured elsewhere in this file.
    #[test]
    fn test_key_feature_extraction() {
        let image = create_test_image();
        let bbox = (1.0, 1.0, 2.0, 2.0);
        let explanation = Array2::ones((4, 4));

        let features = extract_key_features(&image, &bbox, &explanation).unwrap();

        assert!(!features.is_empty());
        // Should have at least center feature
        assert!(features.iter().any(|f| f.name == "Center"));
    }

    // The 4x4 mask is split into four 2x2 quadrants, so pixels next to a different
    // quadrant are boundary pixels. With all-ones explanations the result must
    // keep the mask's shape and have a positive total.
    #[test]
    fn test_boundary_importance_computation() {
        let segmentation = array![[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]];
        let pixel_explanations = Array2::ones((4, 4));

        let boundary_importance =
            compute_boundary_importance(&segmentation, &pixel_explanations).unwrap();

        assert_eq!(boundary_importance.shape(), &[4, 4]);
        // Boundary pixels should have non-zero importance
        assert!(boundary_importance.sum() > 0.0);
    }

    // End-to-end LIME run with a small perturbation budget, checking output shapes.
    #[test]
    fn test_image_lime_explanation() {
        let image = create_test_image();
        let config = ComputerVisionConfig {
            n_superpixels: 4,
            n_perturbations: 10,
            ..Default::default()
        };

        // Mock model function: scores an image by the sum of all its values.
        let model_fn = |image_data: &Array3<Float>| -> SklResult<Array1<Float>> {
            Ok(array![image_data.sum()])
        };

        let result = explain_image_with_lime(&image, model_fn, &config).unwrap();

        assert_eq!(result.superpixels.len(), 4);
        assert_eq!(result.importance_scores.len(), 4);
        assert_eq!(result.explanation_mask.shape(), &[4, 4]);
        assert_eq!(result.positive_mask.shape(), &[4, 4]);
        assert_eq!(result.negative_mask.shape(), &[4, 4]);
    }

    // Exercises equality/inequality of the four `SaliencyMethod` variants; the
    // exhaustive match also fails to compile if a variant is added or removed.
    #[test]
    fn test_saliency_method_variants() {
        use SaliencyMethod::*;

        let methods = vec![Vanilla, IntegratedGradients, SmoothGrad, GuidedBackprop];

        // Test that each variant exists and can be compared
        assert_eq!(methods.len(), 4);

        for method in methods {
            match method {
                Vanilla => assert_eq!(method, Vanilla),
                IntegratedGradients => assert_eq!(method, IntegratedGradients),
                SmoothGrad => assert_eq!(method, SmoothGrad),
                GuidedBackprop => assert_eq!(method, GuidedBackprop),
            }
        }

        // Test inequality
        assert_ne!(Vanilla, IntegratedGradients);
        assert_ne!(SmoothGrad, GuidedBackprop);
    }
}