sklears_inspection/
deep_learning.rs

1//! Deep Learning Interpretability
2//!
3//! This module provides advanced interpretability methods specifically designed for deep neural networks,
4//! including concept activation vectors, neural architecture interpretability, and network dissection.
5
6use crate::{Float, SklResult, SklearsError};
7// ✅ SciRS2 Policy Compliant Import
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3, Axis};
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Configuration for deep learning interpretability methods
///
/// Shared by the TCAV, ACE, and network-dissection entry points on
/// [`DeepLearningAnalyzer`]; construct via [`Default`] and override
/// individual fields as needed.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DeepLearningConfig {
    /// Target layers for analysis
    pub target_layers: Vec<String>,
    /// Number of concepts to extract
    pub num_concepts: usize,
    /// Concept activation threshold
    pub activation_threshold: Float,
    /// Number of test examples for TCAV
    pub num_test_examples: usize,
    /// Random seed for reproducibility
    pub random_seed: Option<u64>,
    /// Method for concept discovery
    pub concept_discovery_method: ConceptDiscoveryMethod,
}
30
impl Default for DeepLearningConfig {
    /// Baseline configuration (the values asserted by this module's tests).
    fn default() -> Self {
        Self {
            // Mid-depth layers are the default analysis targets.
            target_layers: vec!["layer_3".to_string(), "layer_5".to_string()],
            num_concepts: 20,
            activation_threshold: 0.5,
            num_test_examples: 500,
            // Fixed seed keeps runs reproducible unless overridden.
            random_seed: Some(42),
            concept_discovery_method: ConceptDiscoveryMethod::ACE,
        }
    }
}
43
/// Methods for concept discovery
///
/// Selected via [`DeepLearningConfig::concept_discovery_method`]; the
/// default configuration uses [`ConceptDiscoveryMethod::ACE`].
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ConceptDiscoveryMethod {
    /// Automated Concept-based Explanations (ACE)
    ACE,
    /// Testing with Concept Activation Vectors (TCAV)
    TCAV,
    /// Completeness-aware Concept-based Explanations (CCAV)
    CCAV,
    /// Network Dissection
    NetworkDissection,
}
57
/// Concept Activation Vector (CAV) structure
///
/// A CAV is a direction in a layer's activation space; sensitivity to the
/// concept is the dot product of an activation with
/// [`ConceptActivationVector::direction_vector`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ConceptActivationVector {
    /// Concept identifier
    pub concept_id: String,
    /// Layer name where the concept is detected
    pub layer_name: String,
    /// Direction vector in activation space
    pub direction_vector: Array1<Float>,
    /// Concept accuracy (how well it separates concept vs non-concept examples);
    /// 0.0 until a classifier has been trained
    pub accuracy: Float,
    /// Statistical significance (p-value); 1.0 until tested
    pub p_value: Float,
    /// Examples that activate this concept
    pub activating_examples: Vec<usize>,
}
75
76impl ConceptActivationVector {
77    /// Create a new Concept Activation Vector
78    pub fn new(concept_id: String, layer_name: String, direction_vector: Array1<Float>) -> Self {
79        Self {
80            concept_id,
81            layer_name,
82            direction_vector,
83            accuracy: 0.0,
84            p_value: 1.0,
85            activating_examples: Vec::new(),
86        }
87    }
88
89    /// Compute concept sensitivity for a given input
90    pub fn compute_sensitivity(&self, activation: &ArrayView1<Float>) -> Float {
91        activation.dot(&self.direction_vector)
92    }
93
94    /// Check if an input activates this concept
95    pub fn is_activated(&self, activation: &ArrayView1<Float>, threshold: Float) -> bool {
96        self.compute_sensitivity(activation) > threshold
97    }
98}
99
/// TCAV (Testing with Concept Activation Vectors) result
///
/// Produced by [`DeepLearningAnalyzer::compute_tcav`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TCAVResult {
    /// TCAV score (sensitivity to concept): fraction of tested inputs with a
    /// positive directional derivative along the CAV
    pub tcav_score: Float,
    /// Statistical significance (two-tailed p-value vs. the 0.5 null)
    pub p_value: Float,
    /// Confidence interval (95%, clamped to [0, 1])
    pub confidence_interval: (Float, Float),
    /// Number of inputs in class that activate the concept
    pub num_activated: usize,
    /// Total number of inputs tested
    pub total_inputs: usize,
    /// Concept activation vector used
    pub cav: ConceptActivationVector,
}
117
118impl TCAVResult {
119    /// Check if the TCAV result is statistically significant
120    pub fn is_significant(&self, alpha: Float) -> bool {
121        self.p_value < alpha
122    }
123
124    /// Get effect size interpretation
125    pub fn effect_size_interpretation(&self) -> String {
126        match self.tcav_score {
127            score if score < 0.1 => "Negligible effect".to_string(),
128            score if score < 0.3 => "Small effect".to_string(),
129            score if score < 0.5 => "Medium effect".to_string(),
130            _ => "Large effect".to_string(),
131        }
132    }
133}
134
/// Network dissection result
///
/// Produced by [`DeepLearningAnalyzer::perform_network_dissection`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct NetworkDissectionResult {
    /// Layer-wise concept analysis (layer name -> concepts detected there)
    pub layer_concepts: HashMap<String, Vec<DetectedConcept>>,
    /// Overall network interpretability score (mean IoU over detected concepts)
    pub interpretability_score: Float,
    /// Concept hierarchy
    pub concept_hierarchy: ConceptHierarchy,
    /// Disentanglement metrics
    pub disentanglement_metrics: DisentanglementMetrics,
}
148
/// Detected concept in network dissection
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DetectedConcept {
    /// Concept name
    pub name: String,
    /// Concept category (e.g., "object", "texture", "color")
    pub category: String,
    /// IoU (intersection-over-union) score with ground truth concept
    pub iou_score: Float,
    /// Units in the layer that detect this concept
    pub detecting_units: Vec<usize>,
    /// Threshold for concept activation
    pub activation_threshold: Float,
}
164
/// Concept hierarchy structure
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ConceptHierarchy {
    /// Hierarchical relationships between concepts
    /// (values are currently left empty by `build_concept_hierarchy`)
    pub relationships: HashMap<String, Vec<String>>,
    /// Concept abstraction levels (derived from layer ordering)
    pub abstraction_levels: HashMap<String, usize>,
    /// Concept co-occurrence matrix
    /// (currently initialized to all zeros by `build_concept_hierarchy`)
    pub co_occurrence: Array2<Float>,
}
176
/// Disentanglement metrics for evaluating concept separation
///
/// NOTE(review): `compute_disentanglement_metrics` currently returns fixed
/// placeholder values for all four scores.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DisentanglementMetrics {
    /// Mutual Information Gap (MIG) score
    pub mig_score: Float,
    /// Separated Attribute Predictability (SAP) score
    pub sap_score: Float,
    /// Modularity score
    pub modularity_score: Float,
    /// Compactness score
    pub compactness_score: Float,
}
190
/// Main deep learning interpretability analyzer
///
/// Holds the analysis [`DeepLearningConfig`] and a [`ConceptDatabase`];
/// the TCAV, ACE, and network-dissection entry points are methods on
/// this type.
pub struct DeepLearningAnalyzer {
    // Analysis configuration (layers, concept counts, thresholds, seed).
    config: DeepLearningConfig,
    // Storage for learned CAVs; starts empty (see `new`).
    concept_database: ConceptDatabase,
}
196
197impl DeepLearningAnalyzer {
198    /// Create a new deep learning analyzer
199    pub fn new(config: DeepLearningConfig) -> Self {
200        Self {
201            config,
202            concept_database: ConceptDatabase::new(),
203        }
204    }
205
206    /// Perform TCAV analysis
207    pub fn compute_tcav<F>(
208        &self,
209        model_fn: F,
210        concept_examples: &ArrayView2<Float>,
211        random_examples: &ArrayView2<Float>,
212        test_examples: &ArrayView2<Float>,
213        target_class: usize,
214        layer_name: &str,
215    ) -> SklResult<TCAVResult>
216    where
217        F: Fn(&ArrayView2<Float>) -> SklResult<Array2<Float>>,
218    {
219        // 1. Get activations for concept and random examples
220        let concept_activations = model_fn(concept_examples)?;
221        let random_activations = model_fn(random_examples)?;
222
223        // 2. Train linear classifier to separate concept from random
224        let cav = self.train_concept_activation_vector(
225            &concept_activations.view(),
226            &random_activations.view(),
227            layer_name,
228        )?;
229
230        // 3. Compute directional derivatives for test examples
231        let test_activations = model_fn(test_examples)?;
232        let directional_derivatives = self.compute_directional_derivatives(
233            model_fn,
234            test_examples,
235            &cav.direction_vector.view(),
236            target_class,
237        )?;
238
239        // 4. Compute TCAV score
240        let positive_derivatives = directional_derivatives.iter().filter(|&&x| x > 0.0).count();
241
242        let tcav_score = positive_derivatives as Float / directional_derivatives.len() as Float;
243
244        // 5. Statistical testing
245        let (p_value, confidence_interval) =
246            self.compute_tcav_statistics(&directional_derivatives, tcav_score)?;
247
248        Ok(TCAVResult {
249            tcav_score,
250            p_value,
251            confidence_interval,
252            num_activated: positive_derivatives,
253            total_inputs: directional_derivatives.len(),
254            cav,
255        })
256    }
257
258    /// Perform network dissection
259    pub fn perform_network_dissection<F>(
260        &self,
261        model_fn: F,
262        probe_dataset: &ArrayView2<Float>,
263        concept_labels: &HashMap<String, Array1<bool>>,
264    ) -> SklResult<NetworkDissectionResult>
265    where
266        F: Fn(&ArrayView2<Float>) -> SklResult<HashMap<String, Array2<Float>>>,
267    {
268        let layer_activations = model_fn(probe_dataset)?;
269        let mut layer_concepts = HashMap::new();
270
271        // Analyze each layer
272        for (layer_name, activations) in layer_activations.iter() {
273            let detected_concepts =
274                self.detect_concepts_in_layer(activations, concept_labels, layer_name)?;
275            layer_concepts.insert(layer_name.clone(), detected_concepts);
276        }
277
278        // Compute overall interpretability score
279        let interpretability_score = self.compute_interpretability_score(&layer_concepts);
280
281        // Build concept hierarchy
282        let concept_hierarchy = self.build_concept_hierarchy(&layer_concepts)?;
283
284        // Compute disentanglement metrics
285        let disentanglement_metrics = self.compute_disentanglement_metrics(&layer_activations)?;
286
287        Ok(NetworkDissectionResult {
288            layer_concepts,
289            interpretability_score,
290            concept_hierarchy,
291            disentanglement_metrics,
292        })
293    }
294
295    /// Automated Concept Extraction (ACE)
296    pub fn extract_concepts_ace<F>(
297        &self,
298        model_fn: F,
299        images: &ArrayView3<Float>,
300        layer_name: &str,
301        num_concepts: usize,
302    ) -> SklResult<Vec<ConceptActivationVector>>
303    where
304        F: Fn(&ArrayView3<Float>) -> SklResult<Array2<Float>>,
305    {
306        // 1. Get layer activations
307        let activations = model_fn(images)?;
308
309        // 2. Segment images into superpixels
310        let segments = self.segment_images(images)?;
311
312        // 3. Cluster segments based on their activations
313        let concept_clusters = self.cluster_segments(&activations, &segments, num_concepts)?;
314
315        // 4. Create CAVs for each concept cluster
316        let mut cavs = Vec::new();
317        for (i, cluster) in concept_clusters.iter().enumerate() {
318            let concept_id = format!("ace_concept_{}", i);
319            let cav = self.create_cav_from_cluster(
320                concept_id,
321                layer_name.to_string(),
322                cluster,
323                &activations,
324            )?;
325            cavs.push(cav);
326        }
327
328        Ok(cavs)
329    }
330
331    fn train_concept_activation_vector(
332        &self,
333        concept_activations: &ArrayView2<Float>,
334        random_activations: &ArrayView2<Float>,
335        layer_name: &str,
336    ) -> SklResult<ConceptActivationVector> {
337        let n_concept = concept_activations.nrows();
338        let n_random = random_activations.nrows();
339        let n_features = concept_activations.ncols();
340
341        // Create labels: 1 for concept, 0 for random
342        let mut labels = Array1::zeros(n_concept + n_random);
343        for i in 0..n_concept {
344            labels[i] = 1.0;
345        }
346
347        // Combine activations
348        let mut combined_activations = Array2::zeros((n_concept + n_random, n_features));
349        for i in 0..n_concept {
350            combined_activations
351                .row_mut(i)
352                .assign(&concept_activations.row(i));
353        }
354        for i in 0..n_random {
355            combined_activations
356                .row_mut(n_concept + i)
357                .assign(&random_activations.row(i));
358        }
359
360        // Train linear SVM (simplified implementation)
361        let direction_vector =
362            self.train_linear_classifier(&combined_activations.view(), &labels.view())?;
363
364        // Compute accuracy using cross-validation
365        let accuracy = self.compute_classifier_accuracy(
366            &combined_activations.view(),
367            &labels.view(),
368            &direction_vector,
369        )?;
370
371        let mut cav = ConceptActivationVector::new(
372            "trained_concept".to_string(),
373            layer_name.to_string(),
374            direction_vector,
375        );
376        cav.accuracy = accuracy;
377
378        Ok(cav)
379    }
380
381    fn train_linear_classifier(
382        &self,
383        X: &ArrayView2<Float>,
384        y: &ArrayView1<Float>,
385    ) -> SklResult<Array1<Float>> {
386        let n_samples = X.nrows();
387        let n_features = X.ncols();
388
389        if n_samples != y.len() {
390            return Err(SklearsError::InvalidInput(
391                "Number of samples must match number of labels".to_string(),
392            ));
393        }
394
395        // Simple linear regression solution: w = (X^T X)^{-1} X^T y
396        // This is a simplified implementation
397        let mut weights = Array1::zeros(n_features);
398
399        // Use gradient descent for simplicity
400        let learning_rate = 0.01;
401        let max_iterations = 1000;
402
403        for _ in 0..max_iterations {
404            let predictions = X.dot(&weights);
405            let residuals = &predictions - y;
406            let gradient = X.t().dot(&residuals) / n_samples as Float;
407            weights = weights - learning_rate * gradient;
408        }
409
410        Ok(weights)
411    }
412
413    fn compute_classifier_accuracy(
414        &self,
415        X: &ArrayView2<Float>,
416        y: &ArrayView1<Float>,
417        weights: &Array1<Float>,
418    ) -> SklResult<Float> {
419        let predictions = X.dot(weights);
420        let binary_predictions = predictions.mapv(|x| if x > 0.0 { 1.0 } else { 0.0 });
421
422        let correct = binary_predictions
423            .iter()
424            .zip(y.iter())
425            .filter(|(&pred, &true_val)| (pred - true_val).abs() < 1e-6)
426            .count();
427
428        Ok(correct as Float / y.len() as Float)
429    }
430
431    fn compute_directional_derivatives<F>(
432        &self,
433        model_fn: F,
434        inputs: &ArrayView2<Float>,
435        direction: &ArrayView1<Float>,
436        target_class: usize,
437    ) -> SklResult<Array1<Float>>
438    where
439        F: Fn(&ArrayView2<Float>) -> SklResult<Array2<Float>>,
440    {
441        // Simplified gradient computation using finite differences
442        let epsilon = 1e-5;
443        let mut derivatives = Array1::zeros(inputs.nrows());
444
445        for (i, input) in inputs.outer_iter().enumerate() {
446            // Forward perturbation
447            let input_plus = input.to_owned();
448            let input_plus_view = input_plus.insert_axis(Axis(0));
449            let activation_plus = model_fn(&input_plus_view.view())?;
450
451            // Backward perturbation
452            let input_minus = input.to_owned();
453            let input_minus_view = input_minus.insert_axis(Axis(0));
454            let activation_minus = model_fn(&input_minus_view.view())?;
455
456            // Compute gradient approximation
457            let gradient_approx = (&activation_plus - &activation_minus) / (2.0 * epsilon);
458
459            // Directional derivative
460            derivatives[i] = gradient_approx.row(0).dot(direction);
461        }
462
463        Ok(derivatives)
464    }
465
466    fn compute_tcav_statistics(
467        &self,
468        directional_derivatives: &Array1<Float>,
469        tcav_score: Float,
470    ) -> SklResult<(Float, (Float, Float))> {
471        let n = directional_derivatives.len() as Float;
472
473        // Use normal approximation for p-value computation
474        let mean = 0.5; // Under null hypothesis
475        let variance = 0.25 / n; // Binomial variance / n
476        let std_error = variance.sqrt();
477
478        // Z-score
479        let z_score = (tcav_score - mean) / std_error;
480
481        // Two-tailed p-value (simplified)
482        let p_value = 2.0 * (1.0 - self.standard_normal_cdf(z_score.abs()));
483
484        // 95% confidence interval
485        let margin_of_error = 1.96 * std_error;
486        let confidence_interval = (
487            (tcav_score - margin_of_error).max(0.0),
488            (tcav_score + margin_of_error).min(1.0),
489        );
490
491        Ok((p_value, confidence_interval))
492    }
493
    // Standard normal CDF: Phi(x) = (1 + erf(x / sqrt(2))) / 2, built on the
    // polynomial erf approximation below.
    fn standard_normal_cdf(&self, x: Float) -> Float {
        0.5 * (1.0 + self.erf(x / (2.0_f64.sqrt() as Float)))
    }
498
    // Error function via the Abramowitz & Stegun 7.1.26 polynomial
    // approximation (max absolute error ~1.5e-7) — accurate enough for the
    // p-values computed in this module.
    fn erf(&self, x: Float) -> Float {
        let a1 = 0.254829592;
        let a2 = -0.284496736;
        let a3 = 1.421413741;
        let a4 = -1.453152027;
        let a5 = 1.061405429;
        let p = 0.3275911;

        // erf is odd (erf(-x) = -erf(x)), so evaluate on |x| and restore sign.
        let sign = if x < 0.0 { -1.0 } else { 1.0 };
        let x = x.abs();

        let t = 1.0 / (1.0 + p * x);
        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();

        sign * y
    }
516
517    fn detect_concepts_in_layer(
518        &self,
519        activations: &Array2<Float>,
520        concept_labels: &HashMap<String, Array1<bool>>,
521        layer_name: &str,
522    ) -> SklResult<Vec<DetectedConcept>> {
523        let mut detected_concepts = Vec::new();
524
525        for (concept_name, labels) in concept_labels.iter() {
526            // For each unit in the layer, compute IoU with concept
527            for unit_idx in 0..activations.ncols() {
528                let unit_activations = activations.column(unit_idx);
529
530                // Find optimal threshold
531                let (threshold, iou_score) =
532                    self.find_optimal_threshold(&unit_activations, labels)?;
533
534                if iou_score > 0.04 {
535                    // Minimum IoU threshold from Network Dissection paper
536                    detected_concepts.push(DetectedConcept {
537                        name: concept_name.clone(),
538                        category: self.get_concept_category(concept_name),
539                        iou_score,
540                        detecting_units: vec![unit_idx],
541                        activation_threshold: threshold,
542                    });
543                }
544            }
545        }
546
547        Ok(detected_concepts)
548    }
549
550    fn find_optimal_threshold(
551        &self,
552        activations: &ArrayView1<Float>,
553        ground_truth: &Array1<bool>,
554    ) -> SklResult<(Float, Float)> {
555        let mut best_threshold = 0.0;
556        let mut best_iou = 0.0;
557
558        // Try different thresholds
559        let mut sorted_activations: Vec<Float> = activations.to_vec();
560        sorted_activations.sort_by(|a, b| a.partial_cmp(b).unwrap());
561
562        for &threshold in sorted_activations.iter() {
563            let predictions: Array1<bool> = activations.mapv(|x| x > threshold);
564            let iou = self.compute_iou(&predictions, ground_truth);
565
566            if iou > best_iou {
567                best_iou = iou;
568                best_threshold = threshold;
569            }
570        }
571
572        Ok((best_threshold, best_iou))
573    }
574
575    fn compute_iou(&self, predictions: &Array1<bool>, ground_truth: &Array1<bool>) -> Float {
576        let intersection = predictions
577            .iter()
578            .zip(ground_truth.iter())
579            .filter(|(&pred, &gt)| pred && gt)
580            .count() as Float;
581
582        let union = predictions
583            .iter()
584            .zip(ground_truth.iter())
585            .filter(|(&pred, &gt)| pred || gt)
586            .count() as Float;
587
588        if union == 0.0 {
589            0.0
590        } else {
591            intersection / union
592        }
593    }
594
595    fn get_concept_category(&self, concept_name: &str) -> String {
596        // Simple categorization based on concept name
597        if concept_name.contains("color") {
598            "color".to_string()
599        } else if concept_name.contains("texture") {
600            "texture".to_string()
601        } else if concept_name.contains("object") {
602            "object".to_string()
603        } else {
604            "other".to_string()
605        }
606    }
607
608    fn compute_interpretability_score(
609        &self,
610        layer_concepts: &HashMap<String, Vec<DetectedConcept>>,
611    ) -> Float {
612        if layer_concepts.is_empty() {
613            return 0.0;
614        }
615
616        let total_concepts: usize = layer_concepts.values().map(|concepts| concepts.len()).sum();
617        let weighted_iou: Float = layer_concepts
618            .values()
619            .flat_map(|concepts| concepts.iter())
620            .map(|concept| concept.iou_score)
621            .sum();
622
623        if total_concepts == 0 {
624            0.0
625        } else {
626            weighted_iou / total_concepts as Float
627        }
628    }
629
630    fn build_concept_hierarchy(
631        &self,
632        layer_concepts: &HashMap<String, Vec<DetectedConcept>>,
633    ) -> SklResult<ConceptHierarchy> {
634        let mut relationships = HashMap::new();
635        let mut abstraction_levels = HashMap::new();
636
637        // Simple hierarchy based on layer depth (deeper = more abstract)
638        let mut layer_names: Vec<String> = layer_concepts.keys().cloned().collect();
639        layer_names.sort();
640
641        for (level, layer_name) in layer_names.iter().enumerate() {
642            if let Some(concepts) = layer_concepts.get(layer_name) {
643                for concept in concepts {
644                    abstraction_levels.insert(concept.name.clone(), level);
645                    relationships.insert(concept.name.clone(), Vec::new());
646                }
647            }
648        }
649
650        // Create co-occurrence matrix (simplified)
651        let all_concepts: Vec<String> = abstraction_levels.keys().cloned().collect();
652        let n_concepts = all_concepts.len();
653        let co_occurrence = Array2::zeros((n_concepts, n_concepts));
654
655        Ok(ConceptHierarchy {
656            relationships,
657            abstraction_levels,
658            co_occurrence,
659        })
660    }
661
662    fn compute_disentanglement_metrics(
663        &self,
664        layer_activations: &HashMap<String, Array2<Float>>,
665    ) -> SklResult<DisentanglementMetrics> {
666        // Simplified disentanglement metrics computation
667        Ok(DisentanglementMetrics {
668            mig_score: 0.5,         // Placeholder
669            sap_score: 0.6,         // Placeholder
670            modularity_score: 0.7,  // Placeholder
671            compactness_score: 0.8, // Placeholder
672        })
673    }
674
675    fn segment_images(&self, images: &ArrayView3<Float>) -> SklResult<Vec<Vec<(usize, usize)>>> {
676        // Placeholder for image segmentation
677        // In practice, this would use superpixel segmentation algorithms
678        let mut segments = Vec::new();
679        for _ in 0..images.shape()[0] {
680            segments.push(vec![(0, 0), (1, 1)]); // Placeholder segments
681        }
682        Ok(segments)
683    }
684
685    fn cluster_segments(
686        &self,
687        activations: &Array2<Float>,
688        segments: &[Vec<(usize, usize)>],
689        num_concepts: usize,
690    ) -> SklResult<Vec<Vec<usize>>> {
691        // Placeholder for clustering implementation
692        // In practice, this would use k-means or other clustering algorithms
693        let mut clusters = Vec::new();
694        for i in 0..num_concepts {
695            clusters.push(vec![i, i + num_concepts]);
696        }
697        Ok(clusters)
698    }
699
700    fn create_cav_from_cluster(
701        &self,
702        concept_id: String,
703        layer_name: String,
704        cluster: &[usize],
705        activations: &Array2<Float>,
706    ) -> SklResult<ConceptActivationVector> {
707        // Compute mean activation for the cluster
708        let cluster_mean = if cluster.is_empty() {
709            Array1::zeros(activations.ncols())
710        } else {
711            let cluster_activations: Array2<Float> = cluster
712                .iter()
713                .map(|&idx| {
714                    if idx < activations.nrows() {
715                        activations.row(idx).to_owned()
716                    } else {
717                        Array1::zeros(activations.ncols())
718                    }
719                })
720                .collect::<Vec<_>>()
721                .into_iter()
722                .fold(Array2::zeros((0, activations.ncols())), |acc, row| {
723                    if acc.nrows() == 0 {
724                        Array2::from_shape_vec((1, row.len()), row.to_vec()).unwrap()
725                    } else {
726                        let new_shape = (acc.nrows() + 1, acc.ncols());
727                        let mut new_data = acc.into_raw_vec();
728                        new_data.extend(row.iter().cloned());
729                        Array2::from_shape_vec(new_shape, new_data).unwrap()
730                    }
731                });
732
733            cluster_activations.mean_axis(Axis(0)).unwrap()
734        };
735
736        Ok(ConceptActivationVector::new(
737            concept_id,
738            layer_name,
739            cluster_mean,
740        ))
741    }
742}
743
/// Concept database for storing and managing learned concepts
pub struct ConceptDatabase {
    // Learned CAVs keyed by their `concept_id`.
    concepts: HashMap<String, ConceptActivationVector>,
    // Related-concept adjacency lists; not yet populated by any method in
    // this module.
    concept_relationships: HashMap<String, Vec<String>>,
}
749
impl Default for ConceptDatabase {
    /// Equivalent to [`ConceptDatabase::new`]: an empty database.
    fn default() -> Self {
        Self::new()
    }
}
755
756impl ConceptDatabase {
757    pub fn new() -> Self {
758        Self {
759            concepts: HashMap::new(),
760            concept_relationships: HashMap::new(),
761        }
762    }
763
764    pub fn add_concept(&mut self, concept: ConceptActivationVector) {
765        self.concepts.insert(concept.concept_id.clone(), concept);
766    }
767
768    pub fn get_concept(&self, concept_id: &str) -> Option<&ConceptActivationVector> {
769        self.concepts.get(concept_id)
770    }
771
772    pub fn find_similar_concepts(&self, concept_id: &str, threshold: Float) -> Vec<String> {
773        if let Some(target_concept) = self.concepts.get(concept_id) {
774            self.concepts
775                .iter()
776                .filter(|(id, concept)| {
777                    *id != concept_id
778                        && self.compute_concept_similarity(
779                            &target_concept.direction_vector,
780                            &concept.direction_vector,
781                        ) > threshold
782                })
783                .map(|(id, _)| id.clone())
784                .collect()
785        } else {
786            Vec::new()
787        }
788    }
789
790    fn compute_concept_similarity(&self, v1: &Array1<Float>, v2: &Array1<Float>) -> Float {
791        // Cosine similarity
792        let dot_product = v1.dot(v2);
793        let norm1 = v1.dot(v1).sqrt();
794        let norm2 = v2.dot(v2).sqrt();
795
796        if norm1 == 0.0 || norm2 == 0.0 {
797            0.0
798        } else {
799            dot_product / (norm1 * norm2)
800        }
801    }
802}
803
#[cfg(test)]
mod tests {
    use super::*;
    // ✅ SciRS2 Policy Compliant Import
    use scirs2_core::ndarray::Array;

    // Default config should expose the documented baseline values.
    #[test]
    fn test_deep_learning_config_creation() {
        let config = DeepLearningConfig::default();
        assert_eq!(config.num_concepts, 20);
        assert_eq!(config.activation_threshold, 0.5);
        assert!(matches!(
            config.concept_discovery_method,
            ConceptDiscoveryMethod::ACE
        ));
    }

    // Sensitivity is the dot product of an activation with the CAV direction.
    #[test]
    fn test_concept_activation_vector() {
        let direction = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let cav = ConceptActivationVector::new(
            "test_concept".to_string(),
            "layer_1".to_string(),
            direction,
        );

        assert_eq!(cav.concept_id, "test_concept");
        assert_eq!(cav.layer_name, "layer_1");

        let activation = Array1::from_vec(vec![1.0, 1.0, 1.0]);
        let sensitivity = cav.compute_sensitivity(&activation.view());
        assert!((sensitivity - 0.6).abs() < 1e-6); // 0.1 + 0.2 + 0.3
    }

    // is_activated compares the sensitivity against the given threshold.
    #[test]
    fn test_concept_activation_check() {
        let direction = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let cav = ConceptActivationVector::new(
            "test_concept".to_string(),
            "layer_1".to_string(),
            direction,
        );

        let high_activation = Array1::from_vec(vec![0.8, 0.1, 0.1]);
        let low_activation = Array1::from_vec(vec![0.2, 0.1, 0.1]);

        assert!(cav.is_activated(&high_activation.view(), 0.5));
        assert!(!cav.is_activated(&low_activation.view(), 0.5));
    }

    // Significance and effect-size bucketing for a hand-built result.
    #[test]
    fn test_tcav_result() {
        let direction = Array1::from_vec(vec![1.0, 0.0]);
        let cav = ConceptActivationVector::new(
            "test_concept".to_string(),
            "layer_1".to_string(),
            direction,
        );

        let result = TCAVResult {
            tcav_score: 0.75,
            p_value: 0.01,
            confidence_interval: (0.65, 0.85),
            num_activated: 15,
            total_inputs: 20,
            cav,
        };

        assert!(result.is_significant(0.05));
        assert_eq!(result.effect_size_interpretation(), "Large effect");
    }

    // Round-trip a concept through the database's add/get API.
    #[test]
    fn test_concept_database() {
        let mut db = ConceptDatabase::new();

        let direction = Array1::from_vec(vec![1.0, 0.0]);
        let concept = ConceptActivationVector::new(
            "test_concept".to_string(),
            "layer_1".to_string(),
            direction,
        );

        db.add_concept(concept);
        assert!(db.get_concept("test_concept").is_some());
        assert!(db.get_concept("nonexistent").is_none());
    }

    // A fresh analyzer keeps its config and starts with an empty database.
    #[test]
    fn test_deep_learning_analyzer_creation() {
        let config = DeepLearningConfig::default();
        let analyzer = DeepLearningAnalyzer::new(config);

        assert_eq!(analyzer.config.num_concepts, 20);
        assert!(analyzer.concept_database.concepts.is_empty());
    }

    // Plain field construction/access for DetectedConcept.
    #[test]
    fn test_detected_concept() {
        let concept = DetectedConcept {
            name: "stripe_pattern".to_string(),
            category: "texture".to_string(),
            iou_score: 0.65,
            detecting_units: vec![5, 12, 23],
            activation_threshold: 0.4,
        };

        assert_eq!(concept.name, "stripe_pattern");
        assert_eq!(concept.detecting_units.len(), 3);
        assert!(concept.iou_score > 0.6);
    }

    // Plain field construction/access for DisentanglementMetrics.
    #[test]
    fn test_disentanglement_metrics() {
        let metrics = DisentanglementMetrics {
            mig_score: 0.8,
            sap_score: 0.75,
            modularity_score: 0.9,
            compactness_score: 0.85,
        };

        assert!(metrics.mig_score > 0.7);
        assert!(metrics.modularity_score > 0.8);
    }
}