1use crate::{Float, SklResult, SklearsError};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayView3, Axis};
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// Configuration for [`DeepLearningAnalyzer`].
///
/// NOTE(review): no analyzer method in this file currently reads these fields; the
/// per-call arguments (e.g. `num_concepts` of `extract_concepts_ace`) are what take
/// effect. The field descriptions below reflect the intended use suggested by the names.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DeepLearningConfig {
    /// Names of the layers to analyze (default: `layer_3`, `layer_5`).
    pub target_layers: Vec<String>,
    /// Intended number of concepts to discover (default: 20).
    pub num_concepts: usize,
    /// Intended activation threshold for concept detection (default: 0.5).
    pub activation_threshold: Float,
    /// Intended number of test examples (default: 500).
    pub num_test_examples: usize,
    /// Optional random seed (default: `Some(42)`).
    pub random_seed: Option<u64>,
    /// Concept discovery method to use (default: ACE).
    pub concept_discovery_method: ConceptDiscoveryMethod,
}
30
31impl Default for DeepLearningConfig {
32 fn default() -> Self {
33 Self {
34 target_layers: vec!["layer_3".to_string(), "layer_5".to_string()],
35 num_concepts: 20,
36 activation_threshold: 0.5,
37 num_test_examples: 500,
38 random_seed: Some(42),
39 concept_discovery_method: ConceptDiscoveryMethod::ACE,
40 }
41 }
42}
43
/// Concept discovery method selector.
///
/// Variant names mirror the analyses offered by `DeepLearningAnalyzer`
/// (`extract_concepts_ace`, `compute_tcav`, `perform_network_dissection`).
/// NOTE(review): the analyzer does not currently dispatch on this value.
// `PartialEq`/`Eq`/`Hash` let callers compare methods with `==` and use them as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ConceptDiscoveryMethod {
    /// ACE-style automatic concept extraction.
    ACE,
    /// Testing with Concept Activation Vectors.
    TCAV,
    /// CCAV variant. NOTE(review): no implementation in this file — confirm meaning.
    CCAV,
    /// Network Dissection.
    NetworkDissection,
}
57
/// A concept activation vector (CAV): a direction in a layer's activation space
/// associated with a concept.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ConceptActivationVector {
    /// Identifier of the concept.
    pub concept_id: String,
    /// Name of the layer the direction lives in.
    pub layer_name: String,
    /// Direction vector; `compute_sensitivity` takes its dot product with an activation.
    pub direction_vector: Array1<Float>,
    /// Accuracy of the classifier that produced the direction (0.0 when built via `new`).
    pub accuracy: Float,
    /// P-value for the concept (1.0 when built via `new`).
    /// NOTE(review): not updated anywhere in this file.
    pub p_value: Float,
    /// Indices of examples that activate the concept.
    /// NOTE(review): left empty by `new` and not populated in this file.
    pub activating_examples: Vec<usize>,
}
75
76impl ConceptActivationVector {
77 pub fn new(concept_id: String, layer_name: String, direction_vector: Array1<Float>) -> Self {
79 Self {
80 concept_id,
81 layer_name,
82 direction_vector,
83 accuracy: 0.0,
84 p_value: 1.0,
85 activating_examples: Vec::new(),
86 }
87 }
88
89 pub fn compute_sensitivity(&self, activation: &ArrayView1<Float>) -> Float {
91 activation.dot(&self.direction_vector)
92 }
93
94 pub fn is_activated(&self, activation: &ArrayView1<Float>, threshold: Float) -> bool {
96 self.compute_sensitivity(activation) > threshold
97 }
98}
99
/// Outcome of a TCAV test, as produced by `DeepLearningAnalyzer::compute_tcav`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TCAVResult {
    /// Fraction of test inputs whose directional derivative is positive.
    pub tcav_score: Float,
    /// Two-sided p-value of the score against a null score of 0.5 (normal approximation).
    pub p_value: Float,
    /// 95% normal-approximation interval for the score, clipped to `[0, 1]`.
    pub confidence_interval: (Float, Float),
    /// Number of test inputs with a positive directional derivative.
    pub num_activated: usize,
    /// Number of test inputs evaluated.
    pub total_inputs: usize,
    /// The CAV the derivatives were taken along.
    pub cav: ConceptActivationVector,
}
117
118impl TCAVResult {
119 pub fn is_significant(&self, alpha: Float) -> bool {
121 self.p_value < alpha
122 }
123
124 pub fn effect_size_interpretation(&self) -> String {
126 match self.tcav_score {
127 score if score < 0.1 => "Negligible effect".to_string(),
128 score if score < 0.3 => "Small effect".to_string(),
129 score if score < 0.5 => "Medium effect".to_string(),
130 _ => "Large effect".to_string(),
131 }
132 }
133}
134
/// Result of `DeepLearningAnalyzer::perform_network_dissection`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct NetworkDissectionResult {
    /// Detected concepts, keyed by layer name.
    pub layer_concepts: HashMap<String, Vec<DetectedConcept>>,
    /// Mean IoU over all detected concepts in all layers (0.0 when none were detected).
    pub interpretability_score: Float,
    /// Concept-to-layer hierarchy built from the detected concepts.
    pub concept_hierarchy: ConceptHierarchy,
    /// Disentanglement metrics.
    /// NOTE(review): the analyzer currently fills these with fixed placeholder values.
    pub disentanglement_metrics: DisentanglementMetrics,
}
148
/// A concept detected by one or more units of a layer.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DetectedConcept {
    /// Concept name (a key of the concept-label map).
    pub name: String,
    /// Coarse category ("color", "texture", "object" or "other") derived from the name.
    pub category: String,
    /// Best IoU between the thresholded unit activations and the concept mask.
    pub iou_score: Float,
    /// Indices (activation-matrix columns) of the units that detect the concept.
    pub detecting_units: Vec<usize>,
    /// Threshold achieving `iou_score`; a unit "fires" when its activation exceeds it.
    pub activation_threshold: Float,
}
164
/// Relationships between detected concepts across layers.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ConceptHierarchy {
    /// Concept name -> related concept names (the analyzer currently leaves these empty).
    pub relationships: HashMap<String, Vec<String>>,
    /// Concept name -> index, in sorted layer-name order, of the layer it was detected in.
    pub abstraction_levels: HashMap<String, usize>,
    /// Square concept-by-concept co-occurrence matrix (the analyzer currently fills zeros).
    pub co_occurrence: Array2<Float>,
}
176
/// Disentanglement scores for a representation.
///
/// NOTE(review): `DeepLearningAnalyzer` currently returns fixed placeholder values
/// here; the expansions below are inferred from the field names.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct DisentanglementMetrics {
    /// Presumably the Mutual Information Gap score.
    pub mig_score: Float,
    /// Presumably the Separated Attribute Predictability score.
    pub sap_score: Float,
    /// Modularity score.
    pub modularity_score: Float,
    /// Compactness score.
    pub compactness_score: Float,
}
190
/// Entry point for concept-based interpretability analyses (TCAV, network dissection,
/// ACE-style concept extraction) over a caller-supplied model function.
pub struct DeepLearningAnalyzer {
    /// Analyzer configuration.
    /// NOTE(review): not read by any analyzer method in this file.
    config: DeepLearningConfig,
    /// Concept store, created empty by `new`.
    /// NOTE(review): not read or written by any analyzer method in this file.
    concept_database: ConceptDatabase,
}
196
197impl DeepLearningAnalyzer {
198 pub fn new(config: DeepLearningConfig) -> Self {
200 Self {
201 config,
202 concept_database: ConceptDatabase::new(),
203 }
204 }
205
206 pub fn compute_tcav<F>(
208 &self,
209 model_fn: F,
210 concept_examples: &ArrayView2<Float>,
211 random_examples: &ArrayView2<Float>,
212 test_examples: &ArrayView2<Float>,
213 target_class: usize,
214 layer_name: &str,
215 ) -> SklResult<TCAVResult>
216 where
217 F: Fn(&ArrayView2<Float>) -> SklResult<Array2<Float>>,
218 {
219 let concept_activations = model_fn(concept_examples)?;
221 let random_activations = model_fn(random_examples)?;
222
223 let cav = self.train_concept_activation_vector(
225 &concept_activations.view(),
226 &random_activations.view(),
227 layer_name,
228 )?;
229
230 let test_activations = model_fn(test_examples)?;
232 let directional_derivatives = self.compute_directional_derivatives(
233 model_fn,
234 test_examples,
235 &cav.direction_vector.view(),
236 target_class,
237 )?;
238
239 let positive_derivatives = directional_derivatives.iter().filter(|&&x| x > 0.0).count();
241
242 let tcav_score = positive_derivatives as Float / directional_derivatives.len() as Float;
243
244 let (p_value, confidence_interval) =
246 self.compute_tcav_statistics(&directional_derivatives, tcav_score)?;
247
248 Ok(TCAVResult {
249 tcav_score,
250 p_value,
251 confidence_interval,
252 num_activated: positive_derivatives,
253 total_inputs: directional_derivatives.len(),
254 cav,
255 })
256 }
257
258 pub fn perform_network_dissection<F>(
260 &self,
261 model_fn: F,
262 probe_dataset: &ArrayView2<Float>,
263 concept_labels: &HashMap<String, Array1<bool>>,
264 ) -> SklResult<NetworkDissectionResult>
265 where
266 F: Fn(&ArrayView2<Float>) -> SklResult<HashMap<String, Array2<Float>>>,
267 {
268 let layer_activations = model_fn(probe_dataset)?;
269 let mut layer_concepts = HashMap::new();
270
271 for (layer_name, activations) in layer_activations.iter() {
273 let detected_concepts =
274 self.detect_concepts_in_layer(activations, concept_labels, layer_name)?;
275 layer_concepts.insert(layer_name.clone(), detected_concepts);
276 }
277
278 let interpretability_score = self.compute_interpretability_score(&layer_concepts);
280
281 let concept_hierarchy = self.build_concept_hierarchy(&layer_concepts)?;
283
284 let disentanglement_metrics = self.compute_disentanglement_metrics(&layer_activations)?;
286
287 Ok(NetworkDissectionResult {
288 layer_concepts,
289 interpretability_score,
290 concept_hierarchy,
291 disentanglement_metrics,
292 })
293 }
294
295 pub fn extract_concepts_ace<F>(
297 &self,
298 model_fn: F,
299 images: &ArrayView3<Float>,
300 layer_name: &str,
301 num_concepts: usize,
302 ) -> SklResult<Vec<ConceptActivationVector>>
303 where
304 F: Fn(&ArrayView3<Float>) -> SklResult<Array2<Float>>,
305 {
306 let activations = model_fn(images)?;
308
309 let segments = self.segment_images(images)?;
311
312 let concept_clusters = self.cluster_segments(&activations, &segments, num_concepts)?;
314
315 let mut cavs = Vec::new();
317 for (i, cluster) in concept_clusters.iter().enumerate() {
318 let concept_id = format!("ace_concept_{}", i);
319 let cav = self.create_cav_from_cluster(
320 concept_id,
321 layer_name.to_string(),
322 cluster,
323 &activations,
324 )?;
325 cavs.push(cav);
326 }
327
328 Ok(cavs)
329 }
330
331 fn train_concept_activation_vector(
332 &self,
333 concept_activations: &ArrayView2<Float>,
334 random_activations: &ArrayView2<Float>,
335 layer_name: &str,
336 ) -> SklResult<ConceptActivationVector> {
337 let n_concept = concept_activations.nrows();
338 let n_random = random_activations.nrows();
339 let n_features = concept_activations.ncols();
340
341 let mut labels = Array1::zeros(n_concept + n_random);
343 for i in 0..n_concept {
344 labels[i] = 1.0;
345 }
346
347 let mut combined_activations = Array2::zeros((n_concept + n_random, n_features));
349 for i in 0..n_concept {
350 combined_activations
351 .row_mut(i)
352 .assign(&concept_activations.row(i));
353 }
354 for i in 0..n_random {
355 combined_activations
356 .row_mut(n_concept + i)
357 .assign(&random_activations.row(i));
358 }
359
360 let direction_vector =
362 self.train_linear_classifier(&combined_activations.view(), &labels.view())?;
363
364 let accuracy = self.compute_classifier_accuracy(
366 &combined_activations.view(),
367 &labels.view(),
368 &direction_vector,
369 )?;
370
371 let mut cav = ConceptActivationVector::new(
372 "trained_concept".to_string(),
373 layer_name.to_string(),
374 direction_vector,
375 );
376 cav.accuracy = accuracy;
377
378 Ok(cav)
379 }
380
381 fn train_linear_classifier(
382 &self,
383 X: &ArrayView2<Float>,
384 y: &ArrayView1<Float>,
385 ) -> SklResult<Array1<Float>> {
386 let n_samples = X.nrows();
387 let n_features = X.ncols();
388
389 if n_samples != y.len() {
390 return Err(SklearsError::InvalidInput(
391 "Number of samples must match number of labels".to_string(),
392 ));
393 }
394
395 let mut weights = Array1::zeros(n_features);
398
399 let learning_rate = 0.01;
401 let max_iterations = 1000;
402
403 for _ in 0..max_iterations {
404 let predictions = X.dot(&weights);
405 let residuals = &predictions - y;
406 let gradient = X.t().dot(&residuals) / n_samples as Float;
407 weights = weights - learning_rate * gradient;
408 }
409
410 Ok(weights)
411 }
412
413 fn compute_classifier_accuracy(
414 &self,
415 X: &ArrayView2<Float>,
416 y: &ArrayView1<Float>,
417 weights: &Array1<Float>,
418 ) -> SklResult<Float> {
419 let predictions = X.dot(weights);
420 let binary_predictions = predictions.mapv(|x| if x > 0.0 { 1.0 } else { 0.0 });
421
422 let correct = binary_predictions
423 .iter()
424 .zip(y.iter())
425 .filter(|(&pred, &true_val)| (pred - true_val).abs() < 1e-6)
426 .count();
427
428 Ok(correct as Float / y.len() as Float)
429 }
430
431 fn compute_directional_derivatives<F>(
432 &self,
433 model_fn: F,
434 inputs: &ArrayView2<Float>,
435 direction: &ArrayView1<Float>,
436 target_class: usize,
437 ) -> SklResult<Array1<Float>>
438 where
439 F: Fn(&ArrayView2<Float>) -> SklResult<Array2<Float>>,
440 {
441 let epsilon = 1e-5;
443 let mut derivatives = Array1::zeros(inputs.nrows());
444
445 for (i, input) in inputs.outer_iter().enumerate() {
446 let input_plus = input.to_owned();
448 let input_plus_view = input_plus.insert_axis(Axis(0));
449 let activation_plus = model_fn(&input_plus_view.view())?;
450
451 let input_minus = input.to_owned();
453 let input_minus_view = input_minus.insert_axis(Axis(0));
454 let activation_minus = model_fn(&input_minus_view.view())?;
455
456 let gradient_approx = (&activation_plus - &activation_minus) / (2.0 * epsilon);
458
459 derivatives[i] = gradient_approx.row(0).dot(direction);
461 }
462
463 Ok(derivatives)
464 }
465
466 fn compute_tcav_statistics(
467 &self,
468 directional_derivatives: &Array1<Float>,
469 tcav_score: Float,
470 ) -> SklResult<(Float, (Float, Float))> {
471 let n = directional_derivatives.len() as Float;
472
473 let mean = 0.5; let variance = 0.25 / n; let std_error = variance.sqrt();
477
478 let z_score = (tcav_score - mean) / std_error;
480
481 let p_value = 2.0 * (1.0 - self.standard_normal_cdf(z_score.abs()));
483
484 let margin_of_error = 1.96 * std_error;
486 let confidence_interval = (
487 (tcav_score - margin_of_error).max(0.0),
488 (tcav_score + margin_of_error).min(1.0),
489 );
490
491 Ok((p_value, confidence_interval))
492 }
493
494 fn standard_normal_cdf(&self, x: Float) -> Float {
495 0.5 * (1.0 + self.erf(x / (2.0_f64.sqrt() as Float)))
497 }
498
499 fn erf(&self, x: Float) -> Float {
500 let a1 = 0.254829592;
502 let a2 = -0.284496736;
503 let a3 = 1.421413741;
504 let a4 = -1.453152027;
505 let a5 = 1.061405429;
506 let p = 0.3275911;
507
508 let sign = if x < 0.0 { -1.0 } else { 1.0 };
509 let x = x.abs();
510
511 let t = 1.0 / (1.0 + p * x);
512 let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
513
514 sign * y
515 }
516
517 fn detect_concepts_in_layer(
518 &self,
519 activations: &Array2<Float>,
520 concept_labels: &HashMap<String, Array1<bool>>,
521 layer_name: &str,
522 ) -> SklResult<Vec<DetectedConcept>> {
523 let mut detected_concepts = Vec::new();
524
525 for (concept_name, labels) in concept_labels.iter() {
526 for unit_idx in 0..activations.ncols() {
528 let unit_activations = activations.column(unit_idx);
529
530 let (threshold, iou_score) =
532 self.find_optimal_threshold(&unit_activations, labels)?;
533
534 if iou_score > 0.04 {
535 detected_concepts.push(DetectedConcept {
537 name: concept_name.clone(),
538 category: self.get_concept_category(concept_name),
539 iou_score,
540 detecting_units: vec![unit_idx],
541 activation_threshold: threshold,
542 });
543 }
544 }
545 }
546
547 Ok(detected_concepts)
548 }
549
550 fn find_optimal_threshold(
551 &self,
552 activations: &ArrayView1<Float>,
553 ground_truth: &Array1<bool>,
554 ) -> SklResult<(Float, Float)> {
555 let mut best_threshold = 0.0;
556 let mut best_iou = 0.0;
557
558 let mut sorted_activations: Vec<Float> = activations.to_vec();
560 sorted_activations.sort_by(|a, b| a.partial_cmp(b).unwrap());
561
562 for &threshold in sorted_activations.iter() {
563 let predictions: Array1<bool> = activations.mapv(|x| x > threshold);
564 let iou = self.compute_iou(&predictions, ground_truth);
565
566 if iou > best_iou {
567 best_iou = iou;
568 best_threshold = threshold;
569 }
570 }
571
572 Ok((best_threshold, best_iou))
573 }
574
575 fn compute_iou(&self, predictions: &Array1<bool>, ground_truth: &Array1<bool>) -> Float {
576 let intersection = predictions
577 .iter()
578 .zip(ground_truth.iter())
579 .filter(|(&pred, >)| pred && gt)
580 .count() as Float;
581
582 let union = predictions
583 .iter()
584 .zip(ground_truth.iter())
585 .filter(|(&pred, >)| pred || gt)
586 .count() as Float;
587
588 if union == 0.0 {
589 0.0
590 } else {
591 intersection / union
592 }
593 }
594
595 fn get_concept_category(&self, concept_name: &str) -> String {
596 if concept_name.contains("color") {
598 "color".to_string()
599 } else if concept_name.contains("texture") {
600 "texture".to_string()
601 } else if concept_name.contains("object") {
602 "object".to_string()
603 } else {
604 "other".to_string()
605 }
606 }
607
608 fn compute_interpretability_score(
609 &self,
610 layer_concepts: &HashMap<String, Vec<DetectedConcept>>,
611 ) -> Float {
612 if layer_concepts.is_empty() {
613 return 0.0;
614 }
615
616 let total_concepts: usize = layer_concepts.values().map(|concepts| concepts.len()).sum();
617 let weighted_iou: Float = layer_concepts
618 .values()
619 .flat_map(|concepts| concepts.iter())
620 .map(|concept| concept.iou_score)
621 .sum();
622
623 if total_concepts == 0 {
624 0.0
625 } else {
626 weighted_iou / total_concepts as Float
627 }
628 }
629
630 fn build_concept_hierarchy(
631 &self,
632 layer_concepts: &HashMap<String, Vec<DetectedConcept>>,
633 ) -> SklResult<ConceptHierarchy> {
634 let mut relationships = HashMap::new();
635 let mut abstraction_levels = HashMap::new();
636
637 let mut layer_names: Vec<String> = layer_concepts.keys().cloned().collect();
639 layer_names.sort();
640
641 for (level, layer_name) in layer_names.iter().enumerate() {
642 if let Some(concepts) = layer_concepts.get(layer_name) {
643 for concept in concepts {
644 abstraction_levels.insert(concept.name.clone(), level);
645 relationships.insert(concept.name.clone(), Vec::new());
646 }
647 }
648 }
649
650 let all_concepts: Vec<String> = abstraction_levels.keys().cloned().collect();
652 let n_concepts = all_concepts.len();
653 let co_occurrence = Array2::zeros((n_concepts, n_concepts));
654
655 Ok(ConceptHierarchy {
656 relationships,
657 abstraction_levels,
658 co_occurrence,
659 })
660 }
661
662 fn compute_disentanglement_metrics(
663 &self,
664 layer_activations: &HashMap<String, Array2<Float>>,
665 ) -> SklResult<DisentanglementMetrics> {
666 Ok(DisentanglementMetrics {
668 mig_score: 0.5, sap_score: 0.6, modularity_score: 0.7, compactness_score: 0.8, })
673 }
674
675 fn segment_images(&self, images: &ArrayView3<Float>) -> SklResult<Vec<Vec<(usize, usize)>>> {
676 let mut segments = Vec::new();
679 for _ in 0..images.shape()[0] {
680 segments.push(vec![(0, 0), (1, 1)]); }
682 Ok(segments)
683 }
684
685 fn cluster_segments(
686 &self,
687 activations: &Array2<Float>,
688 segments: &[Vec<(usize, usize)>],
689 num_concepts: usize,
690 ) -> SklResult<Vec<Vec<usize>>> {
691 let mut clusters = Vec::new();
694 for i in 0..num_concepts {
695 clusters.push(vec![i, i + num_concepts]);
696 }
697 Ok(clusters)
698 }
699
700 fn create_cav_from_cluster(
701 &self,
702 concept_id: String,
703 layer_name: String,
704 cluster: &[usize],
705 activations: &Array2<Float>,
706 ) -> SklResult<ConceptActivationVector> {
707 let cluster_mean = if cluster.is_empty() {
709 Array1::zeros(activations.ncols())
710 } else {
711 let cluster_activations: Array2<Float> = cluster
712 .iter()
713 .map(|&idx| {
714 if idx < activations.nrows() {
715 activations.row(idx).to_owned()
716 } else {
717 Array1::zeros(activations.ncols())
718 }
719 })
720 .collect::<Vec<_>>()
721 .into_iter()
722 .fold(Array2::zeros((0, activations.ncols())), |acc, row| {
723 if acc.nrows() == 0 {
724 Array2::from_shape_vec((1, row.len()), row.to_vec()).unwrap()
725 } else {
726 let new_shape = (acc.nrows() + 1, acc.ncols());
727 let mut new_data = acc.into_raw_vec();
728 new_data.extend(row.iter().cloned());
729 Array2::from_shape_vec(new_shape, new_data).unwrap()
730 }
731 });
732
733 cluster_activations.mean_axis(Axis(0)).unwrap()
734 };
735
736 Ok(ConceptActivationVector::new(
737 concept_id,
738 layer_name,
739 cluster_mean,
740 ))
741 }
742}
743
744pub struct ConceptDatabase {
746 concepts: HashMap<String, ConceptActivationVector>,
747 concept_relationships: HashMap<String, Vec<String>>,
748}
749
750impl Default for ConceptDatabase {
751 fn default() -> Self {
752 Self::new()
753 }
754}
755
756impl ConceptDatabase {
757 pub fn new() -> Self {
758 Self {
759 concepts: HashMap::new(),
760 concept_relationships: HashMap::new(),
761 }
762 }
763
764 pub fn add_concept(&mut self, concept: ConceptActivationVector) {
765 self.concepts.insert(concept.concept_id.clone(), concept);
766 }
767
768 pub fn get_concept(&self, concept_id: &str) -> Option<&ConceptActivationVector> {
769 self.concepts.get(concept_id)
770 }
771
772 pub fn find_similar_concepts(&self, concept_id: &str, threshold: Float) -> Vec<String> {
773 if let Some(target_concept) = self.concepts.get(concept_id) {
774 self.concepts
775 .iter()
776 .filter(|(id, concept)| {
777 *id != concept_id
778 && self.compute_concept_similarity(
779 &target_concept.direction_vector,
780 &concept.direction_vector,
781 ) > threshold
782 })
783 .map(|(id, _)| id.clone())
784 .collect()
785 } else {
786 Vec::new()
787 }
788 }
789
790 fn compute_concept_similarity(&self, v1: &Array1<Float>, v2: &Array1<Float>) -> Float {
791 let dot_product = v1.dot(v2);
793 let norm1 = v1.dot(v1).sqrt();
794 let norm2 = v2.dot(v2).sqrt();
795
796 if norm1 == 0.0 || norm2 == 0.0 {
797 0.0
798 } else {
799 dot_product / (norm1 * norm2)
800 }
801 }
802}
803
#[cfg(test)]
mod tests {
    // Unit tests for the configuration, CAV, TCAV result, concept database and
    // analyzer construction. The previously imported `ndarray::Array` was unused
    // (unused_imports warning) and has been dropped.
    use super::*;

    #[test]
    fn test_deep_learning_config_creation() {
        let config = DeepLearningConfig::default();
        assert_eq!(config.num_concepts, 20);
        assert_eq!(config.activation_threshold, 0.5);
        assert!(matches!(
            config.concept_discovery_method,
            ConceptDiscoveryMethod::ACE
        ));
    }

    #[test]
    fn test_concept_activation_vector() {
        let direction = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let cav = ConceptActivationVector::new(
            "test_concept".to_string(),
            "layer_1".to_string(),
            direction,
        );

        assert_eq!(cav.concept_id, "test_concept");
        assert_eq!(cav.layer_name, "layer_1");

        let activation = Array1::from_vec(vec![1.0, 1.0, 1.0]);
        let sensitivity = cav.compute_sensitivity(&activation.view());
        // 0.1 + 0.2 + 0.3
        assert!((sensitivity - 0.6).abs() < 1e-6);
    }

    #[test]
    fn test_concept_activation_check() {
        let direction = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let cav = ConceptActivationVector::new(
            "test_concept".to_string(),
            "layer_1".to_string(),
            direction,
        );

        let high_activation = Array1::from_vec(vec![0.8, 0.1, 0.1]);
        let low_activation = Array1::from_vec(vec![0.2, 0.1, 0.1]);

        assert!(cav.is_activated(&high_activation.view(), 0.5));
        assert!(!cav.is_activated(&low_activation.view(), 0.5));
    }

    #[test]
    fn test_tcav_result() {
        let direction = Array1::from_vec(vec![1.0, 0.0]);
        let cav = ConceptActivationVector::new(
            "test_concept".to_string(),
            "layer_1".to_string(),
            direction,
        );

        let result = TCAVResult {
            tcav_score: 0.75,
            p_value: 0.01,
            confidence_interval: (0.65, 0.85),
            num_activated: 15,
            total_inputs: 20,
            cav,
        };

        assert!(result.is_significant(0.05));
        assert_eq!(result.effect_size_interpretation(), "Large effect");
    }

    #[test]
    fn test_concept_database() {
        let mut db = ConceptDatabase::new();

        let direction = Array1::from_vec(vec![1.0, 0.0]);
        let concept = ConceptActivationVector::new(
            "test_concept".to_string(),
            "layer_1".to_string(),
            direction,
        );

        db.add_concept(concept);
        assert!(db.get_concept("test_concept").is_some());
        assert!(db.get_concept("nonexistent").is_none());
    }

    #[test]
    fn test_deep_learning_analyzer_creation() {
        let config = DeepLearningConfig::default();
        let analyzer = DeepLearningAnalyzer::new(config);

        assert_eq!(analyzer.config.num_concepts, 20);
        assert!(analyzer.concept_database.concepts.is_empty());
    }

    #[test]
    fn test_detected_concept() {
        let concept = DetectedConcept {
            name: "stripe_pattern".to_string(),
            category: "texture".to_string(),
            iou_score: 0.65,
            detecting_units: vec![5, 12, 23],
            activation_threshold: 0.4,
        };

        assert_eq!(concept.name, "stripe_pattern");
        assert_eq!(concept.detecting_units.len(), 3);
        assert!(concept.iou_score > 0.6);
    }

    #[test]
    fn test_disentanglement_metrics() {
        let metrics = DisentanglementMetrics {
            mig_score: 0.8,
            sap_score: 0.75,
            modularity_score: 0.9,
            compactness_score: 0.85,
        };

        assert!(metrics.mig_score > 0.7);
        assert!(metrics.modularity_score > 0.8);
    }
}