use crate::{types::Float, SklResult, SklearsError};
use scirs2_core::ndarray::{Array1, Array2, Array3};
use scirs2_core::random::Rng;
use std::collections::HashMap;

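/// Configuration shared by the computer vision explanation methods in this
/// module (superpixel LIME, Grad-CAM, saliency maps, and the object detection
/// and segmentation explainers).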
#[derive(Debug, Clone)]
pub struct ComputerVisionConfig {
    /// Approximate number of superpixels used for image LIME.
    pub n_superpixels: usize,
    /// Number of perturbed samples drawn per LIME explanation.
    pub n_perturbations: usize,
    /// Smoothing applied to Grad-CAM heatmaps; values greater than zero enable a 3x3 box filter.
    pub gradcam_smoothing: Float,
    /// Requested output resolution for saliency maps.
    pub saliency_resolution: (usize, usize),
    /// Minimum confidence for a detection to be included in explanations.
    pub detection_threshold: Float,
    /// Minimum pixel count for a segment to be reported.
    pub min_segment_size: usize,
    /// Amplitude of the noise added by SmoothGrad.
    pub noise_level: Float,
}

impl Default for ComputerVisionConfig {
    fn default() -> Self {
        Self {
            n_superpixels: 100,
            n_perturbations: 1000,
            gradcam_smoothing: 1.0,
            saliency_resolution: (224, 224),
            detection_threshold: 0.5,
            min_segment_size: 50,
            noise_level: 0.1,
        }
    }
}

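/// A dense image with pixel data stored as a `(height, width, channels)` float array.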
#[derive(Debug, Clone)]
pub struct Image {
    /// Pixel data with shape `(height, width, channels)`.
    pub data: Array3<Float>,
    pub width: usize,
    pub height: usize,
    pub channels: usize,
    pub format: String,
}

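/// A group of neighboring pixels treated as a single interpretable feature by the
/// image LIME explainer.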
#[derive(Debug, Clone)]
pub struct Superpixel {
    pub id: usize,
    /// Member pixel coordinates as `(x, y)` pairs.
    pub pixels: Vec<(usize, usize)>,
    pub mean_color: Array1<Float>,
    /// Centroid as `(x, y)` in pixel coordinates.
    pub centroid: (Float, Float),
    pub area: usize,
}

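/// Result of [`explain_image_with_lime`]: per-superpixel importance scores plus
/// pixel-level explanation masks (overall, positive-only, and negative-only).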
#[derive(Debug, Clone)]
pub struct ImageLimeResult {
    pub image: Image,
    pub superpixels: Vec<Superpixel>,
    pub importance_scores: Array1<Float>,
    pub explanation_mask: Array2<Float>,
    pub positive_mask: Array2<Float>,
    pub negative_mask: Array2<Float>,
}

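/// Result of [`generate_gradcam`]: the smoothed heatmap, the raw class activation
/// map, and summary statistics.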
#[derive(Debug, Clone)]
pub struct GradCAMResult {
    pub image: Image,
    pub heatmap: Array2<Float>,
    pub guided_gradcam: Option<Array3<Float>>,
    pub class_activation: Array2<Float>,
    pub target_class: usize,
    pub activation_stats: GradCAMStats,
}

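/// Summary statistics computed over a Grad-CAM heatmap.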
#[derive(Debug, Clone)]
pub struct GradCAMStats {
    pub max_activation: Float,
    pub mean_activation: Float,
    pub std_activation: Float,
    pub high_activation_percentage: Float,
}

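/// Result of [`generate_saliency_map`] for a single image and method.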
#[derive(Debug, Clone)]
pub struct SaliencyMapResult {
    pub image: Image,
    pub saliency_map: Array2<Float>,
    pub integrated_gradients: Option<Array3<Float>>,
    pub smooth_gradients: Option<Array3<Float>>,
    pub method: SaliencyMethod,
}

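/// Attribution method used by [`generate_saliency_map`].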
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SaliencyMethod {
    Vanilla,
    IntegratedGradients,
    SmoothGrad,
    GuidedBackprop,
}

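/// Result of [`explain_object_detection`]: per-detection explanations plus a
/// global attention map and feature-importance map.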
#[derive(Debug, Clone)]
pub struct ObjectDetectionExplanation {
    pub image: Image,
    pub detections: Vec<DetectedObject>,
    pub attention_map: Array2<Float>,
    pub feature_importance: Array2<Float>,
}

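/// A single detection above the confidence threshold, together with its local
/// explanation mask and heuristic key features.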
#[derive(Debug, Clone)]
pub struct DetectedObject {
    /// Bounding box as `(x, y, width, height)` in pixel coordinates.
    pub bbox: (Float, Float, Float, Float),
    pub class: String,
    pub confidence: Float,
    pub local_explanation: Array2<Float>,
    pub key_features: Vec<KeyFeature>,
}

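/// A salient region associated with a detection (location and size are in pixel
/// coordinates).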
#[derive(Debug, Clone)]
pub struct KeyFeature {
    pub name: String,
    pub location: (Float, Float),
    pub importance: Float,
    pub size: (Float, Float),
}

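/// Result of [`explain_segmentation`]: per-pixel explanation values, per-segment
/// summaries, and boundary importance.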
#[derive(Debug, Clone)]
pub struct SegmentationExplanation {
    pub image: Image,
    pub segmentation_mask: Array2<usize>,
    pub pixel_explanations: Array2<Float>,
    pub segments: Vec<SegmentExplanation>,
    pub boundary_importance: Array2<Float>,
}

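/// Explanation summary for a single segment of the segmentation mask.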
#[derive(Debug, Clone)]
pub struct SegmentExplanation {
    pub segment_id: usize,
    pub class: String,
    pub confidence: Float,
    /// Member pixels as `(row, col)` indices.
    pub pixels: Vec<(usize, usize)>,
    pub importance: Float,
    pub key_features: Vec<String>,
}

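/// Explains a classifier's prediction on `image` with a LIME-style procedure:
/// the image is split into superpixels, random subsets of superpixels are masked
/// to zero, and per-superpixel importance scores are estimated from the resulting
/// changes in the model output.
///
/// A minimal usage sketch (illustrative only; the closure stands in for a real
/// model and is not part of this module):
///
/// ```ignore
/// let image = Image {
///     data: Array3::zeros((32, 32, 3)), // (height, width, channels)
///     width: 32,
///     height: 32,
///     channels: 3,
///     format: "RGB".to_string(),
/// };
/// let config = ComputerVisionConfig {
///     n_superpixels: 16,
///     n_perturbations: 200,
///     ..Default::default()
/// };
/// // Toy model: "predicts" the mean intensity of the image.
/// let model = |data: &Array3<Float>| -> SklResult<Array1<Float>> {
///     Ok(Array1::from_vec(vec![data.mean().unwrap_or(0.0)]))
/// };
/// let result = explain_image_with_lime(&image, model, &config)?;
/// assert_eq!(result.importance_scores.len(), result.superpixels.len());
/// ```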
pub fn explain_image_with_lime<F>(
    image: &Image,
    model_fn: F,
    config: &ComputerVisionConfig,
) -> SklResult<ImageLimeResult>
where
    F: Fn(&Array3<Float>) -> SklResult<Array1<Float>>,
{
    // Segment the image into interpretable superpixels.
    let superpixels = generate_superpixels(image, config)?;

    let mut perturbations = Vec::new();
    let mut labels = Vec::new();

    let original_pred = model_fn(&image.data)?;

    // Sample random superpixel on/off patterns and record the model output for each.
    for _ in 0..config.n_perturbations {
        let (perturbed_image, active_superpixels) =
            generate_image_perturbation(image, &superpixels, 0.5)?;
        let pred = model_fn(&perturbed_image)?;

        perturbations.push(active_superpixels);
        labels.push(pred);
    }

    // Estimate per-superpixel importance from the perturbation outcomes.
    let importance_scores = compute_lime_weights_image(&perturbations, &labels, &original_pred)?;

    // Render the scores back onto pixel-level masks.
    let (explanation_mask, positive_mask, negative_mask) =
        generate_explanation_masks(image, &superpixels, &importance_scores)?;

    Ok(ImageLimeResult {
        image: image.clone(),
        superpixels,
        importance_scores,
        explanation_mask,
        positive_mask,
        negative_mask,
    })
}

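/// Generates a Grad-CAM style explanation for `target_class`. The model closure
/// must return both the class scores and a `(height, width, channels)` feature-map
/// tensor; per-channel weights are derived from a gradient surrogate and combined
/// into a ReLU-clipped, optionally smoothed heatmap.
///
/// A minimal usage sketch (illustrative only; `image` and `config` are assumed to
/// be built as in the [`explain_image_with_lime`] example, and the closure stands
/// in for a real convolutional model):
///
/// ```ignore
/// let model = |_data: &Array3<Float>| -> SklResult<(Array1<Float>, Array3<Float>)> {
///     let scores = Array1::from_vec(vec![0.2, 0.8]);
///     let feature_maps = Array3::ones((7, 7, 16)); // stand-in activations
///     Ok((scores, feature_maps))
/// };
/// let result = generate_gradcam(&image, model, 1, &config)?;
/// assert_eq!(result.target_class, 1);
/// ```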
pub fn generate_gradcam<F>(
    image: &Image,
    model_fn: F,
    target_class: usize,
    config: &ComputerVisionConfig,
) -> SklResult<GradCAMResult>
where
    F: Fn(&Array3<Float>) -> SklResult<(Array1<Float>, Array3<Float>)>,
{
    let (predictions, feature_maps) = model_fn(&image.data)?;

    if target_class >= predictions.len() {
        return Err(SklearsError::InvalidInput(
            "Target class index out of range".to_string(),
        ));
    }

    let gradients = compute_gradients(&predictions, &feature_maps, target_class)?;

    let class_activation = compute_class_activation_map(&feature_maps, &gradients)?;

    let heatmap = apply_smoothing(&class_activation, config.gradcam_smoothing)?;

    let activation_stats = compute_gradcam_stats(&heatmap)?;

    Ok(GradCAMResult {
        image: image.clone(),
        heatmap,
        guided_gradcam: None,
        class_activation,
        target_class,
        activation_stats,
    })
}

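/// Computes a pixel-level saliency map with the requested [`SaliencyMethod`]. All
/// methods here are driven by finite differences on the model closure, so the cost
/// grows with `width * height` model evaluations (times the number of path steps
/// or noise samples for integrated gradients and SmoothGrad).
///
/// A minimal usage sketch (illustrative only; `image`, `model`, and `config` are
/// assumed to be defined as in the [`explain_image_with_lime`] example):
///
/// ```ignore
/// let result = generate_saliency_map(&image, model, SaliencyMethod::SmoothGrad, &config)?;
/// assert_eq!(result.saliency_map.dim(), (image.height, image.width));
/// ```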
pub fn generate_saliency_map<F>(
    image: &Image,
    model_fn: F,
    method: SaliencyMethod,
    config: &ComputerVisionConfig,
) -> SklResult<SaliencyMapResult>
where
    F: Fn(&Array3<Float>) -> SklResult<Array1<Float>>,
{
    let saliency_map = match method {
        SaliencyMethod::Vanilla => compute_vanilla_gradients(image, &model_fn)?,
        SaliencyMethod::IntegratedGradients => {
            compute_integrated_gradients(image, &model_fn, config)?
        }
        SaliencyMethod::SmoothGrad => compute_smooth_gradients(image, &model_fn, config)?,
        SaliencyMethod::GuidedBackprop => compute_guided_backprop(image, &model_fn)?,
    };

    Ok(SaliencyMapResult {
        image: image.clone(),
        saliency_map,
        integrated_gradients: None,
        smooth_gradients: None,
        method,
    })
}

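/// Explains an object detector's output. The model closure returns raw detections
/// as `(x, y, width, height, class, confidence)` tuples; detections below
/// `config.detection_threshold` are discarded, and each remaining box receives a
/// local explanation mask plus heuristic key features.
///
/// A minimal usage sketch (illustrative only; the closure fakes a single detection
/// and `image`/`config` are assumed to exist):
///
/// ```ignore
/// let detector = |_data: &Array3<Float>| -> SklResult<Vec<(Float, Float, Float, Float, String, Float)>> {
///     Ok(vec![(4.0, 4.0, 8.0, 8.0, "cat".to_string(), 0.9)])
/// };
/// let result = explain_object_detection(&image, detector, &config)?;
/// assert_eq!(result.detections.len(), 1);
/// ```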
pub fn explain_object_detection<F>(
    image: &Image,
    model_fn: F,
    config: &ComputerVisionConfig,
) -> SklResult<ObjectDetectionExplanation>
where
    F: Fn(&Array3<Float>) -> SklResult<Vec<(Float, Float, Float, Float, String, Float)>>,
{
    let detections_raw = model_fn(&image.data)?;

    let filtered_detections: Vec<_> = detections_raw
        .into_iter()
        .filter(|(_, _, _, _, _, conf)| *conf >= config.detection_threshold)
        .collect();

    let mut detections = Vec::new();
    for (x, y, w, h, class, confidence) in filtered_detections {
        let bbox = (x, y, w, h);
        let local_explanation = generate_detection_explanation(image, &bbox, config)?;
        let key_features = extract_key_features(image, &bbox, &local_explanation)?;

        detections.push(DetectedObject {
            bbox,
            class,
            confidence,
            local_explanation,
            key_features,
        });
    }

    let attention_map = generate_global_attention_map(image, &detections)?;

    let feature_importance = compute_detection_feature_importance(image, &detections)?;

    Ok(ObjectDetectionExplanation {
        image: image.clone(),
        detections,
        attention_map,
        feature_importance,
    })
}

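/// Explains a semantic segmentation output. The model closure returns a per-pixel
/// segment-id mask along with a per-segment confidence vector; segments smaller
/// than `config.min_segment_size` are skipped in the per-segment summaries.
///
/// A minimal usage sketch (illustrative only; the closure fakes a two-segment mask
/// and `image`/`config` are assumed to exist):
///
/// ```ignore
/// let segmenter = |data: &Array3<Float>| -> SklResult<(Array2<usize>, Array1<Float>)> {
///     let (h, w, _) = data.dim();
///     // Left half is segment 0, right half is segment 1.
///     let mask = Array2::from_shape_fn((h, w), |(_, j)| if j < w / 2 { 0 } else { 1 });
///     Ok((mask, Array1::from_vec(vec![0.9, 0.8])))
/// };
/// let result = explain_segmentation(&image, segmenter, &config)?;
/// ```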
pub fn explain_segmentation<F>(
    image: &Image,
    model_fn: F,
    config: &ComputerVisionConfig,
) -> SklResult<SegmentationExplanation>
where
    F: Fn(&Array3<Float>) -> SklResult<(Array2<usize>, Array1<Float>)>,
{
    let (segmentation_mask, confidence) = model_fn(&image.data)?;

    let pixel_explanations = generate_pixel_explanations(image, &segmentation_mask, config)?;

    let segments = extract_segments(&segmentation_mask, &pixel_explanations, &confidence, config)?;

    let boundary_importance = compute_boundary_importance(&segmentation_mask, &pixel_explanations)?;

    Ok(SegmentationExplanation {
        image: image.clone(),
        segmentation_mask,
        pixel_explanations,
        segments,
        boundary_importance,
    })
}

/// Approximates superpixels with a regular grid whose cell count is close to
/// `config.n_superpixels` (a simple stand-in for SLIC-style segmentation).
fn generate_superpixels(
    image: &Image,
    config: &ComputerVisionConfig,
) -> SklResult<Vec<Superpixel>> {
    let mut superpixels = Vec::new();

    // Use a square grid; clamp to at least one cell to avoid division by zero.
    let grid_size = ((config.n_superpixels as Float).sqrt() as usize).max(1);
    let step_x = image.width / grid_size;
    let step_y = image.height / grid_size;

    let mut id = 0;
    for row in 0..grid_size {
        for col in 0..grid_size {
            let start_x = col * step_x;
            let end_x = ((col + 1) * step_x).min(image.width);
            let start_y = row * step_y;
            let end_y = ((row + 1) * step_y).min(image.height);

            let mut pixels = Vec::new();
            let mut color_sum = Array1::zeros(image.channels);

            for y in start_y..end_y {
                for x in start_x..end_x {
                    pixels.push((x, y));
                    if y < image.data.shape()[0] && x < image.data.shape()[1] {
                        for c in 0..image.channels {
                            if c < image.data.shape()[2] {
                                color_sum[c] += image.data[[y, x, c]];
                            }
                        }
                    }
                }
            }

            let mean_color = if !pixels.is_empty() {
                color_sum / pixels.len() as Float
            } else {
                Array1::zeros(image.channels)
            };

            let centroid = (
                (start_x + end_x) as Float / 2.0,
                (start_y + end_y) as Float / 2.0,
            );

            superpixels.push(Superpixel {
                id,
                pixels,
                mean_color,
                centroid,
                area: (end_x - start_x) * (end_y - start_y),
            });

            id += 1;
        }
    }

    Ok(superpixels)
}

/// Randomly masks superpixels to zero, returning the perturbed image together with
/// a per-superpixel flag indicating whether it was kept.
fn generate_image_perturbation(
    image: &Image,
    superpixels: &[Superpixel],
    mask_probability: Float,
) -> SklResult<(Array3<Float>, Vec<bool>)> {
    let mut perturbed = image.data.clone();
    let mut active_superpixels = vec![false; superpixels.len()];

    for (i, superpixel) in superpixels.iter().enumerate() {
        let is_active = scirs2_core::random::thread_rng().random::<Float>() > mask_probability;
        active_superpixels[i] = is_active;

        if !is_active {
            // Zero out every pixel in the masked superpixel.
            for &(x, y) in &superpixel.pixels {
                if y < perturbed.shape()[0] && x < perturbed.shape()[1] {
                    for c in 0..perturbed.shape()[2] {
                        perturbed[[y, x, c]] = 0.0;
                    }
                }
            }
        }
    }

    Ok((perturbed, active_superpixels))
}

/// Computes a per-superpixel weight as the average, over all perturbations, of the
/// superpixel's on/off indicator times the absolute change in the first model
/// output (a lightweight stand-in for LIME's weighted linear regression).
fn compute_lime_weights_image(
    perturbations: &[Vec<bool>],
    predictions: &[Array1<Float>],
    original_pred: &Array1<Float>,
) -> SklResult<Array1<Float>> {
    if perturbations.is_empty() || predictions.is_empty() {
        return Err(SklearsError::InvalidInput(
            "No perturbations provided".to_string(),
        ));
    }

    let n_superpixels = perturbations[0].len();
    let mut weights = Array1::zeros(n_superpixels);

    for j in 0..n_superpixels {
        let mut correlation = 0.0;
        let mut count = 0;

        for (i, pred) in predictions.iter().enumerate() {
            if i < perturbations.len() {
                let feature_active = perturbations[i][j] as usize as Float;
                let target_change = (pred[0] - original_pred[0]).abs();
                correlation += feature_active * target_change;
                count += 1;
            }
        }

        if count > 0 {
            weights[j] = correlation / count as Float;
        }
    }

    Ok(weights)
}

fn generate_explanation_masks(
    image: &Image,
    superpixels: &[Superpixel],
    importance_scores: &Array1<Float>,
) -> SklResult<(Array2<Float>, Array2<Float>, Array2<Float>)> {
    let mut explanation_mask = Array2::zeros((image.height, image.width));
    let mut positive_mask = Array2::zeros((image.height, image.width));
    let mut negative_mask = Array2::zeros((image.height, image.width));

    for (i, superpixel) in superpixels.iter().enumerate() {
        let importance = importance_scores[i];

        for &(x, y) in &superpixel.pixels {
            if y < explanation_mask.shape()[0] && x < explanation_mask.shape()[1] {
                explanation_mask[[y, x]] = importance;

                if importance > 0.0 {
                    positive_mask[[y, x]] = importance;
                } else {
                    negative_mask[[y, x]] = importance.abs();
                }
            }
        }
    }

    Ok((explanation_mask, positive_mask, negative_mask))
}

/// Surrogate gradient: scales each feature-map activation by the target class
/// score. This is a simple stand-in for backpropagated gradients, which are not
/// exposed by the closure-based model interface used here.
fn compute_gradients(
    predictions: &Array1<Float>,
    feature_maps: &Array3<Float>,
    target_class: usize,
) -> SklResult<Array3<Float>> {
    let mut gradients = Array3::zeros(feature_maps.raw_dim());

    if target_class < predictions.len() {
        let target_score = predictions[target_class];

        for ((i, j, k), &feature_val) in feature_maps.indexed_iter() {
            gradients[[i, j, k]] = feature_val * target_score;
        }
    }

    Ok(gradients)
}

/// Builds a Grad-CAM style class activation map: channel weights are the spatial
/// average of the gradients, and the weighted feature-map sum is passed through ReLU.
fn compute_class_activation_map(
    feature_maps: &Array3<Float>,
    gradients: &Array3<Float>,
) -> SklResult<Array2<Float>> {
    if feature_maps.shape() != gradients.shape() {
        return Err(SklearsError::InvalidInput(
            "Feature maps and gradients shape mismatch".to_string(),
        ));
    }

    let (height, width, _) = feature_maps.dim();
    let mut activation_map = Array2::zeros((height, width));

    // Global-average-pool the gradients to get one weight per channel.
    let mut weights = Array1::zeros(feature_maps.shape()[2]);
    for k in 0..feature_maps.shape()[2] {
        let mut sum = 0.0;
        let mut count = 0;

        for i in 0..height {
            for j in 0..width {
                sum += gradients[[i, j, k]];
                count += 1;
            }
        }

        weights[k] = if count > 0 { sum / count as Float } else { 0.0 };
    }

    // Weighted sum over channels, followed by ReLU.
    for i in 0..height {
        for j in 0..width {
            let mut activation: Float = 0.0;
            for k in 0..feature_maps.shape()[2] {
                activation += weights[k] * feature_maps[[i, j, k]];
            }
            activation_map[[i, j]] = activation.max(0.0);
        }
    }

    Ok(activation_map)
}

/// Applies a 3x3 box filter to the heatmap interior when `smoothing > 0`.
fn apply_smoothing(heatmap: &Array2<Float>, smoothing: Float) -> SklResult<Array2<Float>> {
    let mut smoothed = heatmap.clone();

    let (height, width) = heatmap.dim();
    // Skip degenerate heatmaps so the interior index range cannot underflow.
    if smoothing > 0.0 && height > 2 && width > 2 {
        for i in 1..(height - 1) {
            for j in 1..(width - 1) {
                let mut sum = 0.0;
                let mut count = 0;

                for di in -1i32..=1 {
                    for dj in -1i32..=1 {
                        let ni = (i as i32 + di) as usize;
                        let nj = (j as i32 + dj) as usize;

                        if ni < height && nj < width {
                            sum += heatmap[[ni, nj]];
                            count += 1;
                        }
                    }
                }

                if count > 0 {
                    smoothed[[i, j]] = sum / count as Float;
                }
            }
        }
    }

    Ok(smoothed)
}

fn compute_gradcam_stats(heatmap: &Array2<Float>) -> SklResult<GradCAMStats> {
    let values: Vec<Float> = heatmap.iter().cloned().collect();

    if values.is_empty() {
        return Ok(GradCAMStats {
            max_activation: 0.0,
            mean_activation: 0.0,
            std_activation: 0.0,
            high_activation_percentage: 0.0,
        });
    }

    let max_activation = values.iter().cloned().fold(Float::NEG_INFINITY, Float::max);
    let mean_activation = values.iter().sum::<Float>() / values.len() as Float;

    let variance = values
        .iter()
        .map(|&x| (x - mean_activation).powi(2))
        .sum::<Float>()
        / values.len() as Float;
    let std_activation = variance.sqrt();

    let threshold = mean_activation + std_activation;
    let high_count = values.iter().filter(|&&x| x > threshold).count();
    let high_activation_percentage = high_count as Float / values.len() as Float;

    Ok(GradCAMStats {
        max_activation,
        mean_activation,
        std_activation,
        high_activation_percentage,
    })
}

/// Approximates pixel saliency with forward finite differences on the first model
/// output: each pixel is nudged by `epsilon` across all channels and the resulting
/// change in the prediction is recorded.
fn compute_vanilla_gradients<F>(image: &Image, model_fn: F) -> SklResult<Array2<Float>>
where
    F: Fn(&Array3<Float>) -> SklResult<Array1<Float>>,
{
    let (height, width) = (image.height, image.width);
    let mut gradients = Array2::zeros((height, width));

    let original_pred = model_fn(&image.data)?;
    let epsilon = 1e-4;

    for i in 0..height {
        for j in 0..width {
            let mut perturbed = image.data.clone();

            for c in 0..image.channels {
                if c < perturbed.shape()[2] {
                    perturbed[[i, j, c]] += epsilon;
                }
            }

            let perturbed_pred = model_fn(&perturbed)?;
            let gradient = (perturbed_pred[0] - original_pred[0]) / epsilon;
            gradients[[i, j]] = gradient;
        }
    }

    Ok(gradients)
}

/// Simplified integrated gradients: averages finite-difference gradients at 50
/// points along the straight-line path from an all-zero baseline to the input.
/// (The element-wise `(input - baseline)` scaling of full integrated gradients is
/// omitted here.)
fn compute_integrated_gradients<F>(
    image: &Image,
    model_fn: F,
    config: &ComputerVisionConfig,
) -> SklResult<Array2<Float>>
where
    F: Fn(&Array3<Float>) -> SklResult<Array1<Float>>,
{
    let n_steps = 50;
    let baseline = Array3::<Float>::zeros(image.data.raw_dim());

    let mut integrated_gradients = Array2::zeros((image.height, image.width));

    for step in 0..n_steps {
        let alpha = step as Float / n_steps as Float;

        let interpolated = &baseline + alpha * (&image.data - &baseline);

        let gradients = compute_vanilla_gradients(
            &Image {
                data: interpolated,
                width: image.width,
                height: image.height,
                channels: image.channels,
                format: image.format.clone(),
            },
            &model_fn,
        )?;

        integrated_gradients = integrated_gradients + gradients;
    }

    integrated_gradients /= n_steps as Float;

    Ok(integrated_gradients)
}

/// SmoothGrad: averages finite-difference gradients over 50 copies of the input
/// perturbed with zero-centered uniform noise of amplitude `config.noise_level`.
fn compute_smooth_gradients<F>(
    image: &Image,
    model_fn: F,
    config: &ComputerVisionConfig,
) -> SklResult<Array2<Float>>
where
    F: Fn(&Array3<Float>) -> SklResult<Array1<Float>>,
{
    let n_samples = 50;
    let mut smooth_gradients = Array2::zeros((image.height, image.width));

    for _ in 0..n_samples {
        // Add zero-centered uniform noise to every channel value.
        let mut noisy_image = image.data.clone();
        for elem in noisy_image.iter_mut() {
            *elem +=
                (scirs2_core::random::thread_rng().random::<Float>() - 0.5) * config.noise_level;
        }

        let gradients = compute_vanilla_gradients(
            &Image {
                data: noisy_image,
                width: image.width,
                height: image.height,
                channels: image.channels,
                format: image.format.clone(),
            },
            &model_fn,
        )?;

        smooth_gradients = smooth_gradients + gradients;
    }

    smooth_gradients /= n_samples as Float;
    Ok(smooth_gradients)
}

/// Guided backpropagation is not available through the closure-based model
/// interface, so this currently falls back to the vanilla finite-difference
/// saliency map.
fn compute_guided_backprop<F>(image: &Image, model_fn: F) -> SklResult<Array2<Float>>
where
    F: Fn(&Array3<Float>) -> SklResult<Array1<Float>>,
{
    compute_vanilla_gradients(image, model_fn)
}

/// Produces a binary mask marking the pixels inside the detection's bounding box.
fn generate_detection_explanation(
    image: &Image,
    bbox: &(Float, Float, Float, Float),
    config: &ComputerVisionConfig,
) -> SklResult<Array2<Float>> {
    let (x, y, w, h) = *bbox;
    let mut explanation = Array2::zeros((image.height, image.width));

    let x_start = (x as usize).min(image.width);
    let y_start = (y as usize).min(image.height);
    let x_end = ((x + w) as usize).min(image.width);
    let y_end = ((y + h) as usize).min(image.height);

    for i in y_start..y_end {
        for j in x_start..x_end {
            explanation[[i, j]] = 1.0;
        }
    }

    Ok(explanation)
}

/// Returns heuristic key features (box center and corners) derived purely from the
/// bounding-box geometry.
fn extract_key_features(
    image: &Image,
    bbox: &(Float, Float, Float, Float),
    explanation: &Array2<Float>,
) -> SklResult<Vec<KeyFeature>> {
    let mut features = Vec::new();

    let (x, y, w, h) = *bbox;

    features.push(KeyFeature {
        name: "Center".to_string(),
        location: (x + w / 2.0, y + h / 2.0),
        importance: 1.0,
        size: (w, h),
    });

    features.push(KeyFeature {
        name: "Top-left corner".to_string(),
        location: (x, y),
        importance: 0.8,
        size: (w * 0.1, h * 0.1),
    });

    features.push(KeyFeature {
        name: "Bottom-right corner".to_string(),
        location: (x + w, y + h),
        importance: 0.8,
        size: (w * 0.1, h * 0.1),
    });

    Ok(features)
}

fn generate_global_attention_map(
    image: &Image,
    detections: &[DetectedObject],
) -> SklResult<Array2<Float>> {
    let mut attention_map = Array2::zeros((image.height, image.width));

    for detection in detections {
        for i in 0..image.height {
            for j in 0..image.width {
                if i < detection.local_explanation.shape()[0]
                    && j < detection.local_explanation.shape()[1]
                {
                    attention_map[[i, j]] +=
                        detection.local_explanation[[i, j]] * detection.confidence;
                }
            }
        }
    }

    Ok(attention_map)
}

fn compute_detection_feature_importance(
    image: &Image,
    detections: &[DetectedObject],
) -> SklResult<Array2<Float>> {
    let mut feature_importance = Array2::zeros((image.height, image.width));

    for detection in detections {
        let (x, y, w, h) = detection.bbox;
        let x_start = (x as usize).min(image.width);
        let y_start = (y as usize).min(image.height);
        let x_end = ((x + w) as usize).min(image.width);
        let y_end = ((y + h) as usize).min(image.height);

        for i in y_start..y_end {
            for j in x_start..x_end {
                feature_importance[[i, j]] += detection.confidence;
            }
        }
    }

    Ok(feature_importance)
}

/// Placeholder pixel-level explanation: encodes each pixel's segment id, scaled by
/// 1/256.
fn generate_pixel_explanations(
    image: &Image,
    segmentation: &Array2<usize>,
    config: &ComputerVisionConfig,
) -> SklResult<Array2<Float>> {
    let mut explanations = Array2::zeros((image.height, image.width));

    for i in 0..image.height {
        for j in 0..image.width {
            if i < segmentation.shape()[0] && j < segmentation.shape()[1] {
                let segment_id = segmentation[[i, j]];
                explanations[[i, j]] = (segment_id as Float + 1.0) / 256.0;
            }
        }
    }

    Ok(explanations)
}

/// Groups pixels by segment id and summarizes each segment that meets the minimum
/// size, using the average of its per-pixel explanation values as its importance.
fn extract_segments(
    segmentation: &Array2<usize>,
    pixel_explanations: &Array2<Float>,
    confidence: &Array1<Float>,
    config: &ComputerVisionConfig,
) -> SklResult<Vec<SegmentExplanation>> {
    let mut segments = Vec::new();
    let mut segment_pixels: HashMap<usize, Vec<(usize, usize)>> = HashMap::new();

    for ((i, j), &segment_id) in segmentation.indexed_iter() {
        segment_pixels.entry(segment_id).or_default().push((i, j));
    }

    for (segment_id, pixels) in segment_pixels {
        if pixels.len() >= config.min_segment_size {
            let importance = pixels
                .iter()
                .map(|&(i, j)| {
                    if i < pixel_explanations.shape()[0] && j < pixel_explanations.shape()[1] {
                        pixel_explanations[[i, j]]
                    } else {
                        0.0
                    }
                })
                .sum::<Float>()
                / pixels.len() as Float;

            let segment_confidence = if segment_id < confidence.len() {
                confidence[segment_id]
            } else {
                0.5
            };

            segments.push(SegmentExplanation {
                segment_id,
                class: format!("class_{}", segment_id),
                confidence: segment_confidence,
                pixels,
                importance,
                key_features: vec!["color".to_string(), "texture".to_string()],
            });
        }
    }

    Ok(segments)
}

/// Marks pixels whose 8-neighborhood contains a different segment id and copies
/// their explanation values; all other pixels remain zero.
fn compute_boundary_importance(
    segmentation: &Array2<usize>,
    pixel_explanations: &Array2<Float>,
) -> SklResult<Array2<Float>> {
    let (height, width) = segmentation.dim();
    let mut boundary_importance = Array2::zeros((height, width));

    // Degenerate masks have no interior pixels to inspect.
    if height < 3 || width < 3 {
        return Ok(boundary_importance);
    }

    for i in 1..(height - 1) {
        for j in 1..(width - 1) {
            let center_segment = segmentation[[i, j]];
            let mut is_boundary = false;

            for di in -1i32..=1 {
                for dj in -1i32..=1 {
                    if di == 0 && dj == 0 {
                        continue;
                    }

                    let ni = (i as i32 + di) as usize;
                    let nj = (j as i32 + dj) as usize;

                    if ni < height && nj < width && segmentation[[ni, nj]] != center_segment {
                        is_boundary = true;
                        break;
                    }
                }
                if is_boundary {
                    break;
                }
            }

            if is_boundary {
                boundary_importance[[i, j]] = pixel_explanations[[i, j]];
            }
        }
    }

    Ok(boundary_importance)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    fn create_test_image() -> Image {
        Image {
            data: Array3::zeros((4, 4, 3)),
            width: 4,
            height: 4,
            channels: 3,
            format: "RGB".to_string(),
        }
    }

    #[test]
    fn test_computer_vision_config_default() {
        let config = ComputerVisionConfig::default();

        assert_eq!(config.n_superpixels, 100);
        assert_eq!(config.n_perturbations, 1000);
        assert_eq!(config.gradcam_smoothing, 1.0);
        assert_eq!(config.saliency_resolution, (224, 224));
        assert_eq!(config.detection_threshold, 0.5);
        assert_eq!(config.min_segment_size, 50);
        assert_eq!(config.noise_level, 0.1);
    }

    #[test]
    fn test_superpixel_generation() {
        let image = create_test_image();
        let config = ComputerVisionConfig {
            n_superpixels: 4,
            ..Default::default()
        };

        let superpixels = generate_superpixels(&image, &config).unwrap();

        assert_eq!(superpixels.len(), 4);
        for superpixel in &superpixels {
            assert!(!superpixel.pixels.is_empty());
            assert_eq!(superpixel.mean_color.len(), 3);
        }
    }

    #[test]
    fn test_image_perturbation() {
        let image = create_test_image();
        let config = ComputerVisionConfig {
            n_superpixels: 4,
            ..Default::default()
        };
        let superpixels = generate_superpixels(&image, &config).unwrap();

        let (perturbed, active) = generate_image_perturbation(&image, &superpixels, 0.5).unwrap();

        assert_eq!(perturbed.shape(), image.data.shape());
        assert_eq!(active.len(), superpixels.len());
    }

    #[test]
    fn test_gradcam_stats() {
        let heatmap = array![[0.1, 0.8], [0.3, 0.9]];
        let stats = compute_gradcam_stats(&heatmap).unwrap();

        assert_eq!(stats.max_activation, 0.9);
        assert!(stats.mean_activation > 0.0);
        assert!(stats.std_activation > 0.0);
        assert!(stats.high_activation_percentage >= 0.0);
        assert!(stats.high_activation_percentage <= 1.0);
    }

    #[test]
    fn test_smoothing() {
        let heatmap = array![[1.0, 0.0], [0.0, 1.0]];
        let smoothed = apply_smoothing(&heatmap, 1.0).unwrap();

        assert_eq!(smoothed.shape(), heatmap.shape());
    }

    #[test]
    fn test_class_activation_map() {
        let feature_maps = Array3::ones((2, 2, 3));
        let gradients = Array3::ones((2, 2, 3));

        let activation_map = compute_class_activation_map(&feature_maps, &gradients).unwrap();

        assert_eq!(activation_map.shape(), &[2, 2]);
        for &val in activation_map.iter() {
            assert!(val >= 0.0);
        }
    }

    #[test]
    fn test_explanation_mask_generation() {
        let image = create_test_image();
        let config = ComputerVisionConfig {
            n_superpixels: 4,
            ..Default::default()
        };
        let superpixels = generate_superpixels(&image, &config).unwrap();
        let importance_scores = Array1::from_vec(vec![0.5, -0.3, 0.8, -0.1]);

        let (explanation_mask, positive_mask, negative_mask) =
            generate_explanation_masks(&image, &superpixels, &importance_scores).unwrap();

        assert_eq!(explanation_mask.shape(), &[4, 4]);
        assert_eq!(positive_mask.shape(), &[4, 4]);
        assert_eq!(negative_mask.shape(), &[4, 4]);

        for i in 0..4 {
            for j in 0..4 {
                if positive_mask[[i, j]] > 0.0 {
                    assert_eq!(negative_mask[[i, j]], 0.0);
                }
                if negative_mask[[i, j]] > 0.0 {
                    assert_eq!(positive_mask[[i, j]], 0.0);
                }
            }
        }
    }

    #[test]
    fn test_key_feature_extraction() {
        let image = create_test_image();
        let bbox = (1.0, 1.0, 2.0, 2.0);
        let explanation = Array2::ones((4, 4));

        let features = extract_key_features(&image, &bbox, &explanation).unwrap();

        assert!(!features.is_empty());
        assert!(features.iter().any(|f| f.name == "Center"));
    }

    #[test]
    fn test_boundary_importance_computation() {
        let segmentation = array![[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]];
        let pixel_explanations = Array2::ones((4, 4));

        let boundary_importance =
            compute_boundary_importance(&segmentation, &pixel_explanations).unwrap();

        assert_eq!(boundary_importance.shape(), &[4, 4]);
        assert!(boundary_importance.sum() > 0.0);
    }

    #[test]
    fn test_image_lime_explanation() {
        let image = create_test_image();
        let config = ComputerVisionConfig {
            n_superpixels: 4,
            n_perturbations: 10,
            ..Default::default()
        };

        let model_fn = |image_data: &Array3<Float>| -> SklResult<Array1<Float>> {
            Ok(array![image_data.sum()])
        };

        let result = explain_image_with_lime(&image, model_fn, &config).unwrap();

        assert_eq!(result.superpixels.len(), 4);
        assert_eq!(result.importance_scores.len(), 4);
        assert_eq!(result.explanation_mask.shape(), &[4, 4]);
        assert_eq!(result.positive_mask.shape(), &[4, 4]);
        assert_eq!(result.negative_mask.shape(), &[4, 4]);
    }

    #[test]
    fn test_saliency_method_variants() {
        use SaliencyMethod::*;

        let methods = vec![Vanilla, IntegratedGradients, SmoothGrad, GuidedBackprop];

        assert_eq!(methods.len(), 4);

        for method in methods {
            match method {
                Vanilla => assert_eq!(method, Vanilla),
                IntegratedGradients => assert_eq!(method, IntegratedGradients),
                SmoothGrad => assert_eq!(method, SmoothGrad),
                GuidedBackprop => assert_eq!(method, GuidedBackprop),
            }
        }

        assert_ne!(Vanilla, IntegratedGradients);
        assert_ne!(SmoothGrad, GuidedBackprop);
    }
}