
// trustformers_debug/interpretability_tools.rs

//! # Interpretability Tools
//!
//! Comprehensive model interpretability toolkit including SHAP integration, LIME support,
//! attention analysis, feature attribution, and counterfactual generation for TrustformeRS models.
//!
//! ## Refactoring Summary
//!
//! Previously this was a single 2,803-line file containing all interpretability functionality.
//! It has been split into focused modules (a layout sketch follows this list):
//!
//! - `interpretability/config.rs` - Configuration structures and enums (77 lines)
//! - `interpretability/shap.rs` - SHAP analysis types and functionality (66 lines)
//! - `interpretability/lime.rs` - LIME analysis types and functionality (78 lines)
//! - `interpretability/attention.rs` - Attention analysis for transformers (426 lines)
//! - `interpretability/attribution.rs` - Feature attribution methods (103 lines)
//! - `interpretability/counterfactual.rs` - Counterfactual generation (191 lines)
//! - `interpretability/analyzer.rs` - Main analyzer implementation (318 lines)
//! - `interpretability/report.rs` - Reporting functionality (23 lines)
//!
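//! A minimal sketch of what the resulting module layout could look like. This is an
//! illustration only: the `mod.rs` contents and the re-export selection shown here are
//! assumptions based on the module list above, not the actual (currently disabled)
//! implementation.
//!
//! ```rust,ignore
//! // interpretability/mod.rs (hypothetical)
//! pub mod analyzer;
//! pub mod attention;
//! pub mod attribution;
//! pub mod config;
//! pub mod counterfactual;
//! pub mod lime;
//! pub mod report;
//! pub mod shap;
//!
//! // Flatten the most commonly used types at the module root.
//! pub use analyzer::InterpretabilityAnalyzer;
//! pub use config::{AttributionMethod, InterpretabilityConfig};
//! pub use report::InterpretabilityReport;
//! ```
//!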
//! This refactoring improves:
//! - Code maintainability and readability
//! - Module compilation times
//! - Test isolation
//! - Code reuse through focused modules
//! - Developer experience when working on specific interpretability methods

// TODO: Re-enable when the interpretability module is implemented.
// Re-export the entire interpretability module:
// pub use self::interpretability::*;

// Import the interpretability module:
// mod interpretability;

// TODO: Re-enable when the interpretability module is implemented.
// Convenience exports for backwards compatibility:
/*
pub use interpretability::{
    // Configuration
    InterpretabilityConfig,
    AttributionMethod,

    // SHAP analysis
    ShapAnalysisResult,
    FeatureContribution,
    ShapSummary,

    // LIME analysis
    LimeAnalysisResult,
    FeatureImportance,
    PerturbationAnalysis,
    PerturbationResult,
    NeighborhoodStats,

    // Attention analysis
    AttentionAnalysisResult,
    AttentionLayerResult,
    AttentionHeadResult,
    TokenAttentionScore,
    HeadSpecializationType,
    AttentionPatterns,
    DiagonalPattern,
    VerticalPattern,
    BlockPattern,
    RepetitivePattern,
    LayerAttentionPatterns,
    LayerAttentionStats,
    HeadSpecializationAnalysis,
    HeadCluster,
    SpecializationEvolution,
    SpecializationTransition,
    SpecializationTrend,
    HeadRedundancyAnalysis,
    RedundantHeadPair,
    RedundancyType,
    PruningRecommendation,
    PruningImpact,
    RiskLevel,
    AttentionFlowAnalysis,
    AttentionFlowPath,
    LayerFlowStep,
    FlowTransformation,
    AttentionBottleneck,
    BottleneckType,
    FlowEfficiencyMetrics,
    LayerFlowStats,
    AttentionStatistics,
    SparsityDistribution,
    AttentionInsight,
    InsightType,

    // Feature attribution
    FeatureAttributionResult,
    AttributionMethodResult,
    FeatureAttribution,
    MethodAgreementAnalysis,
    TopFeature,
    AttributionVisualizationData,
    TimelinePoint,
    FeatureInteraction,
    InteractionType,

    // Counterfactual generation
    CounterfactualResult,
    Counterfactual,
    FeatureChange,
    ChangeDirection,
    CounterfactualQualityMetrics,
    FeatureSensitivityAnalysis,
    InteractionEffect,
    InteractionEffectType,
    ThresholdAnalysis,
    DecisionBoundaryAnalysis,
    BoundaryCrossingPoint,
    ActionableInsight,
    ImplementationDifficulty,
    TimeHorizon,

    // Main analyzer
    InterpretabilityAnalyzer,

    // Reporting
    InterpretabilityReport,
};
*/

// Tests kept for backwards compatibility with the pre-refactoring file.
// NOTE: they assume the `interpretability` re-exports above have been re-enabled.
#[cfg(test)]
mod tests {

    use crate::{InterpretabilityAnalyzer, InterpretabilityConfig};
    use std::collections::HashMap;

    #[tokio::test]
    async fn test_interpretability_analyzer_creation() {
        // Assumes `InterpretabilityConfig` implements `Default`.
        let config = InterpretabilityConfig::default();
        let _analyzer = InterpretabilityAnalyzer::new(config);
        // Basic test to ensure the analyzer can be created.
    }

    #[tokio::test]
    async fn test_shap_analysis() {
        let config = InterpretabilityConfig::default();
        let analyzer = InterpretabilityAnalyzer::new(config);

        let mut instance = HashMap::new();
        instance.insert("feature1".to_string(), 1.0);
        instance.insert("feature2".to_string(), 2.0);

        let model_predictions = vec![0.8, 0.7, 0.9];
        let background_data = vec![{
            let mut bg = HashMap::new();
            bg.insert("feature1".to_string(), 0.5);
            bg.insert("feature2".to_string(), 1.0);
            bg
        }];

        let result = analyzer.analyze_shap(&instance, &model_predictions, &background_data).await;
        assert!(result.is_ok());
    }
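
    // A minimal, self-contained sketch of the marginal-contribution idea behind
    // SHAP: swap one feature at a time to its background value and measure how
    // the prediction changes. The toy linear model and feature names mirror the
    // test above; this sketch does not touch the (currently disabled) analyzer API.
    #[test]
    fn test_shap_style_marginal_contribution_sketch() {
        let model_fn = |input: &HashMap<String, f64>| -> f64 { input.values().sum::<f64>() * 0.1 };

        let mut instance = HashMap::new();
        instance.insert("feature1".to_string(), 1.0);
        instance.insert("feature2".to_string(), 2.0);

        let mut background = HashMap::new();
        background.insert("feature1".to_string(), 0.5);
        background.insert("feature2".to_string(), 1.0);

        let full = model_fn(&instance);
        let mut contributions = HashMap::new();
        for name in instance.keys() {
            // Mask a single feature with its background value; the resulting
            // drop in the prediction is that feature's marginal contribution.
            let mut masked = instance.clone();
            masked.insert(name.clone(), background[name]);
            contributions.insert(name.clone(), full - model_fn(&masked));
        }

        // For a linear model each contribution is exactly (x_i - baseline_i) * weight.
        assert!((contributions["feature1"] - 0.05).abs() < 1e-9);
        assert!((contributions["feature2"] - 0.10).abs() < 1e-9);
    }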

    #[tokio::test]
    async fn test_lime_analysis() {
        let config = InterpretabilityConfig::default();
        let analyzer = InterpretabilityAnalyzer::new(config);

        let mut instance = HashMap::new();
        instance.insert("feature1".to_string(), 1.0);
        instance.insert("feature2".to_string(), 2.0);

        // Toy model: a linear function of the feature values.
        let model_fn =
            Box::new(|input: &HashMap<String, f64>| -> f64 { input.values().sum::<f64>() * 0.1 });

        let result = analyzer.analyze_lime(&instance, model_fn).await;
        assert!(result.is_ok());
    }
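
    // A minimal, self-contained sketch of the neighborhood perturbation that LIME
    // relies on: nudge each feature around the instance, score the perturbed points
    // with the model, and read the prediction shift as local importance. Uses the
    // same toy linear model as above; the analyzer API is not exercised.
    #[test]
    fn test_lime_style_perturbation_sketch() {
        let model_fn = |input: &HashMap<String, f64>| -> f64 { input.values().sum::<f64>() * 0.1 };

        let mut instance = HashMap::new();
        instance.insert("feature1".to_string(), 1.0);
        instance.insert("feature2".to_string(), 2.0);

        let step = 0.5;
        let base = model_fn(&instance);
        for (name, value) in instance.clone() {
            // Perturb one feature at a time and record the prediction shift;
            // larger shifts mean the feature matters more locally.
            let mut perturbed = instance.clone();
            perturbed.insert(name.clone(), value + step);
            let delta = model_fn(&perturbed) - base;
            // The toy model is linear, so every feature shifts the output by step * 0.1.
            assert!((delta - step * 0.1).abs() < 1e-9, "unexpected shift for {name}");
        }
    }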
}