Skip to main content

trustformers_debug/
behavior_analysis.rs

//! Behavior Analysis
//!
//! Advanced analysis tools for understanding neural network behavior, including
//! input sensitivity, feature importance, neuron activation patterns, dead-neuron
//! detection, and correlation analysis.

6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9
10/// Configuration for behavior analysis
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct BehaviorAnalysisConfig {
13    /// Enable input sensitivity analysis
14    pub enable_input_sensitivity: bool,
15    /// Enable feature importance calculations
16    pub enable_feature_importance: bool,
17    /// Enable neuron activation pattern analysis
18    pub enable_activation_patterns: bool,
19    /// Enable dead neuron detection
20    pub enable_dead_neuron_detection: bool,
21    /// Enable correlation analysis
22    pub enable_correlation_analysis: bool,
23    /// Threshold for dead neuron detection (activation below this value)
24    pub dead_neuron_threshold: f32,
25    /// Number of samples for sensitivity analysis
26    pub sensitivity_samples: usize,
27    /// Perturbation magnitude for sensitivity analysis
28    pub perturbation_magnitude: f32,
29    /// Correlation threshold for significance
30    pub correlation_threshold: f32,
31}
32
33impl Default for BehaviorAnalysisConfig {
34    fn default() -> Self {
35        Self {
36            enable_input_sensitivity: true,
37            enable_feature_importance: true,
38            enable_activation_patterns: true,
39            enable_dead_neuron_detection: true,
40            enable_correlation_analysis: true,
41            dead_neuron_threshold: 1e-6,
42            sensitivity_samples: 100,
43            perturbation_magnitude: 0.01,
44            correlation_threshold: 0.5,
45        }
46    }
47}
48
49/// Input sensitivity analysis results
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct InputSensitivity {
52    pub input_dimension: usize,
53    pub sensitivity_score: f32,
54    pub gradient_magnitude: f32,
55    pub perturbation_impact: f32,
56    pub rank: usize,
57}
58
59/// Feature importance analysis results
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct FeatureImportance {
62    pub feature_id: String,
63    pub importance_score: f32,
64    pub attribution_method: AttributionMethod,
65    pub confidence: f32,
66    pub rank: usize,
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub enum AttributionMethod {
71    GradientBased,
72    PermutationImportance,
73    ShapleySampling,
74    IntegratedGradients,
75    LimeApproximation,
76}
77
78/// Neuron activation pattern information
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct NeuronActivationPattern {
81    pub layer_id: String,
82    pub neuron_id: usize,
83    pub activation_statistics: ActivationStatistics,
84    pub pattern_type: ActivationPatternType,
85    pub stability_score: f32,
86    pub selectivity_score: f32,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct ActivationStatistics {
91    pub mean: f32,
92    pub std: f32,
93    pub min: f32,
94    pub max: f32,
95    pub percentile_25: f32,
96    pub percentile_75: f32,
97    pub skewness: f32,
98    pub kurtosis: f32,
99    pub sparsity: f32, // Fraction of near-zero activations
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub enum ActivationPatternType {
104    Normal,
105    Saturated,
106    Dead,
107    Oscillating,
108    Sparse,
109    Dense,
110    Bipolar,
111}
112
113/// Dead neuron detection results
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct DeadNeuronInfo {
116    pub layer_id: String,
117    pub neuron_id: usize,
118    pub activation_level: f32,
119    pub dead_probability: f32,
120    pub suggested_action: NeuronRepairAction,
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub enum NeuronRepairAction {
125    Reinitialize,
126    AdjustLearningRate,
127    ChangeActivationFunction,
128    AddNoise,
129    Skip, // Neuron is functioning normally
130}
131
132/// Correlation analysis results
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct CorrelationAnalysis {
135    pub correlation_matrix: Vec<Vec<f32>>,
136    pub significant_correlations: Vec<CorrelationPair>,
137    pub redundant_features: Vec<FeatureGroup>,
138    pub independent_features: Vec<usize>,
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct CorrelationPair {
143    pub feature_a: usize,
144    pub feature_b: usize,
145    pub correlation: f32,
146    pub p_value: f32,
147    pub relationship_type: CorrelationType,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub enum CorrelationType {
152    Strong,
153    Moderate,
154    Weak,
155    None,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct FeatureGroup {
160    pub features: Vec<usize>,
161    pub average_correlation: f32,
162    pub group_importance: f32,
163}
164
165/// Comprehensive behavior analysis report
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct BehaviorAnalysisReport {
168    pub input_sensitivities: Vec<InputSensitivity>,
169    pub feature_importances: Vec<FeatureImportance>,
170    pub activation_patterns: Vec<NeuronActivationPattern>,
171    pub dead_neurons: Vec<DeadNeuronInfo>,
172    pub correlation_analysis: Option<CorrelationAnalysis>,
173    pub behavior_summary: BehaviorSummary,
174    pub recommendations: Vec<BehaviorRecommendation>,
175}
176
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct BehaviorSummary {
179    pub total_neurons_analyzed: usize,
180    pub dead_neuron_percentage: f32,
181    pub average_activation_sparsity: f32,
182    pub feature_distribution_entropy: f32,
183    pub model_stability_score: f32,
184    pub interpretability_score: f32,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct BehaviorRecommendation {
189    pub category: RecommendationCategory,
190    pub priority: Priority,
191    pub description: String,
192    pub implementation: String,
193    pub expected_impact: f32,
194}
195
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub enum RecommendationCategory {
198    Architecture,
199    Training,
200    Initialization,
201    Regularization,
202    DataPreprocessing,
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize)]
206pub enum Priority {
207    Critical,
208    High,
209    Medium,
210    Low,
211}
212
213/// Behavior analyzer
214#[derive(Debug)]
215pub struct BehaviorAnalyzer {
216    config: BehaviorAnalysisConfig,
217    activation_history: HashMap<String, Vec<Vec<f32>>>,
218    input_gradients: HashMap<String, Vec<f32>>,
219    feature_attributions: HashMap<String, FeatureImportance>,
220    analysis_cache: HashMap<String, BehaviorAnalysisReport>,
221}
222
223impl BehaviorAnalyzer {
224    /// Create a new behavior analyzer
225    pub fn new(config: BehaviorAnalysisConfig) -> Self {
226        Self {
227            config,
228            activation_history: HashMap::new(),
229            input_gradients: HashMap::new(),
230            feature_attributions: HashMap::new(),
231            analysis_cache: HashMap::new(),
232        }
233    }
234
235    /// Record neuron activations for analysis
236    pub fn record_activations(&mut self, layer_id: String, activations: Vec<f32>) {
237        self.activation_history.entry(layer_id).or_default().push(activations);
238    }
239
240    /// Record input gradients for sensitivity analysis
241    pub fn record_input_gradients(&mut self, input_id: String, gradients: Vec<f32>) {
242        self.input_gradients.insert(input_id, gradients);
243    }
244
245    /// Perform comprehensive behavior analysis
246    pub async fn analyze(&mut self) -> Result<BehaviorAnalysisReport> {
247        let mut report = BehaviorAnalysisReport {
248            input_sensitivities: Vec::new(),
249            feature_importances: Vec::new(),
250            activation_patterns: Vec::new(),
251            dead_neurons: Vec::new(),
252            correlation_analysis: None,
253            behavior_summary: BehaviorSummary {
254                total_neurons_analyzed: 0,
255                dead_neuron_percentage: 0.0,
256                average_activation_sparsity: 0.0,
257                feature_distribution_entropy: 0.0,
258                model_stability_score: 0.0,
259                interpretability_score: 0.0,
260            },
261            recommendations: Vec::new(),
262        };
263
264        if self.config.enable_input_sensitivity {
265            report.input_sensitivities = self.analyze_input_sensitivity().await?;
266        }
267
268        if self.config.enable_feature_importance {
269            report.feature_importances = self.calculate_feature_importance().await?;
270        }
271
272        if self.config.enable_activation_patterns {
273            report.activation_patterns = self.analyze_activation_patterns().await?;
274        }
275
276        if self.config.enable_dead_neuron_detection {
277            report.dead_neurons = self.detect_dead_neurons().await?;
278        }
279
280        if self.config.enable_correlation_analysis {
281            report.correlation_analysis = Some(self.perform_correlation_analysis().await?);
282        }
283
284        self.generate_behavior_summary(&mut report);
285        self.generate_recommendations(&mut report);
286
287        Ok(report)
288    }
289
290    /// Analyze input sensitivity using gradient-based methods
291    async fn analyze_input_sensitivity(&self) -> Result<Vec<InputSensitivity>> {
292        let mut sensitivities = Vec::new();
293
294        for gradients in self.input_gradients.values() {
295            for (dim, &gradient) in gradients.iter().enumerate() {
296                let sensitivity_score = gradient.abs();
297                let gradient_magnitude = gradient.abs();
298
299                // Simulate perturbation impact (would normally require model re-evaluation)
300                let perturbation_impact = self.estimate_perturbation_impact(gradient, dim);
301
302                sensitivities.push(InputSensitivity {
303                    input_dimension: dim,
304                    sensitivity_score,
305                    gradient_magnitude,
306                    perturbation_impact,
307                    rank: 0, // Will be set after sorting
308                });
309            }
310        }
311
312        // Sort by sensitivity score and assign ranks
313        sensitivities.sort_by(|a, b| {
314            b.sensitivity_score
315                .partial_cmp(&a.sensitivity_score)
316                .unwrap_or(std::cmp::Ordering::Equal)
317        });
318        for (rank, sensitivity) in sensitivities.iter_mut().enumerate() {
319            sensitivity.rank = rank + 1;
320        }
321
322        Ok(sensitivities)
323    }
324
325    /// Estimate perturbation impact (simplified version)
326    fn estimate_perturbation_impact(&self, gradient: f32, _dimension: usize) -> f32 {
327        // Simplified estimation: perturbation impact is proportional to gradient magnitude
328        gradient.abs() * self.config.perturbation_magnitude
329    }
330
331    /// Calculate feature importance using multiple methods
332    async fn calculate_feature_importance(&self) -> Result<Vec<FeatureImportance>> {
333        let mut importances = Vec::new();
334
335        // Gradient-based importance
336        for (input_id, gradients) in &self.input_gradients {
337            let total_gradient = gradients.iter().map(|g| g.abs()).sum::<f32>();
338            let importance_score = total_gradient / gradients.len() as f32;
339
340            importances.push(FeatureImportance {
341                feature_id: input_id.clone(),
342                importance_score,
343                attribution_method: AttributionMethod::GradientBased,
344                confidence: self.calculate_attribution_confidence(importance_score),
345                rank: 0,
346            });
347        }
348
349        // Sort by importance and assign ranks
350        importances.sort_by(|a, b| {
351            b.importance_score
352                .partial_cmp(&a.importance_score)
353                .unwrap_or(std::cmp::Ordering::Equal)
354        });
355        for (rank, importance) in importances.iter_mut().enumerate() {
356            importance.rank = rank + 1;
357        }
358
359        Ok(importances)
360    }
361
362    /// Calculate confidence in attribution score
363    fn calculate_attribution_confidence(&self, score: f32) -> f32 {
364        // Simple confidence based on score magnitude
365        (score.tanh() * 0.5 + 0.5).min(1.0)
366    }
367
368    /// Analyze neuron activation patterns
369    async fn analyze_activation_patterns(&self) -> Result<Vec<NeuronActivationPattern>> {
370        let mut patterns = Vec::new();
371
372        for (layer_id, activation_history) in &self.activation_history {
373            if activation_history.is_empty() {
374                continue;
375            }
376
377            let neuron_count = activation_history[0].len();
378
379            for neuron_id in 0..neuron_count {
380                let neuron_activations: Vec<f32> = activation_history
381                    .iter()
382                    .map(|batch| batch.get(neuron_id).copied().unwrap_or(0.0))
383                    .collect();
384
385                let statistics = self.compute_activation_statistics(&neuron_activations);
386                let pattern_type = self.classify_activation_pattern(&statistics);
387                let stability_score = self.compute_stability_score(&neuron_activations);
388                let selectivity_score = self.compute_selectivity_score(&neuron_activations);
389
390                patterns.push(NeuronActivationPattern {
391                    layer_id: layer_id.clone(),
392                    neuron_id,
393                    activation_statistics: statistics,
394                    pattern_type,
395                    stability_score,
396                    selectivity_score,
397                });
398            }
399        }
400
401        Ok(patterns)
402    }
403
404    /// Compute detailed activation statistics
405    fn compute_activation_statistics(&self, activations: &[f32]) -> ActivationStatistics {
406        if activations.is_empty() {
407            return ActivationStatistics {
408                mean: 0.0,
409                std: 0.0,
410                min: 0.0,
411                max: 0.0,
412                percentile_25: 0.0,
413                percentile_75: 0.0,
414                skewness: 0.0,
415                kurtosis: 0.0,
416                sparsity: 1.0,
417            };
418        }
419
420        let mean = activations.iter().sum::<f32>() / activations.len() as f32;
421        let variance =
422            activations.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / activations.len() as f32;
423        let std = variance.sqrt();
424
425        let mut sorted_activations = activations.to_vec();
426        sorted_activations.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
427
428        let min = sorted_activations[0];
429        let max = sorted_activations[sorted_activations.len() - 1];
430        let percentile_25 = sorted_activations[sorted_activations.len() / 4];
431        let percentile_75 = sorted_activations[3 * sorted_activations.len() / 4];
432
433        // Calculate skewness and kurtosis
434        let skewness = if std > 0.0 {
435            activations.iter().map(|&x| ((x - mean) / std).powi(3)).sum::<f32>()
436                / activations.len() as f32
437        } else {
438            0.0
439        };
440
441        let kurtosis = if std > 0.0 {
442            activations.iter().map(|&x| ((x - mean) / std).powi(4)).sum::<f32>()
443                / activations.len() as f32
444                - 3.0
445        } else {
446            0.0
447        };
448
449        // Calculate sparsity (fraction of near-zero activations)
450        let near_zero_count = activations
451            .iter()
452            .filter(|&&x| x.abs() < self.config.dead_neuron_threshold)
453            .count();
454        let sparsity = near_zero_count as f32 / activations.len() as f32;
455
456        ActivationStatistics {
457            mean,
458            std,
459            min,
460            max,
461            percentile_25,
462            percentile_75,
463            skewness,
464            kurtosis,
465            sparsity,
466        }
467    }
468
469    /// Classify activation pattern type
470    fn classify_activation_pattern(&self, stats: &ActivationStatistics) -> ActivationPatternType {
471        if stats.sparsity > 0.9 {
472            ActivationPatternType::Dead
473        } else if stats.sparsity > 0.7 {
474            ActivationPatternType::Sparse
475        } else if stats.max > 0.95 && stats.mean > 0.8 {
476            ActivationPatternType::Saturated
477        } else if stats.std / stats.mean.abs().max(1e-8) > 2.0 {
478            ActivationPatternType::Oscillating
479        } else if stats.mean.abs() > 0.1 && stats.mean * stats.min < 0.0 {
480            ActivationPatternType::Bipolar
481        } else if stats.sparsity < 0.3 {
482            ActivationPatternType::Dense
483        } else {
484            ActivationPatternType::Normal
485        }
486    }
487
488    /// Compute stability score for neuron activations
489    fn compute_stability_score(&self, activations: &[f32]) -> f32 {
490        if activations.len() < 2 {
491            return 0.0;
492        }
493
494        let mean = activations.iter().sum::<f32>() / activations.len() as f32;
495        let variance =
496            activations.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / activations.len() as f32;
497
498        // Stability is inverse of coefficient of variation
499        if mean.abs() > 1e-8 {
500            1.0 / (1.0 + variance.sqrt() / mean.abs())
501        } else {
502            0.0
503        }
504    }
505
506    /// Compute selectivity score (how selective the neuron is)
507    fn compute_selectivity_score(&self, activations: &[f32]) -> f32 {
508        if activations.is_empty() {
509            return 0.0;
510        }
511
512        // Selectivity based on activation distribution
513        let max_activation = activations.iter().fold(0.0f32, |a, &b| a.max(b.abs()));
514        let mean_activation =
515            activations.iter().map(|x| x.abs()).sum::<f32>() / activations.len() as f32;
516
517        if max_activation > 1e-8 {
518            1.0 - (mean_activation / max_activation)
519        } else {
520            0.0
521        }
522    }
523
524    /// Detect dead neurons
525    async fn detect_dead_neurons(&self) -> Result<Vec<DeadNeuronInfo>> {
526        let mut dead_neurons = Vec::new();
527
528        for (layer_id, activation_history) in &self.activation_history {
529            if activation_history.is_empty() {
530                continue;
531            }
532
533            let neuron_count = activation_history[0].len();
534
535            for neuron_id in 0..neuron_count {
536                let neuron_activations: Vec<f32> = activation_history
537                    .iter()
538                    .map(|batch| batch.get(neuron_id).copied().unwrap_or(0.0))
539                    .collect();
540
541                let activation_level = neuron_activations.iter().map(|x| x.abs()).sum::<f32>()
542                    / neuron_activations.len() as f32;
543
544                let dead_probability = if activation_level < self.config.dead_neuron_threshold {
545                    1.0 - (activation_level / self.config.dead_neuron_threshold)
546                } else {
547                    0.0
548                };
549
550                if dead_probability > 0.5 {
551                    let suggested_action =
552                        self.suggest_neuron_repair_action(activation_level, &neuron_activations);
553
554                    dead_neurons.push(DeadNeuronInfo {
555                        layer_id: layer_id.clone(),
556                        neuron_id,
557                        activation_level,
558                        dead_probability,
559                        suggested_action,
560                    });
561                }
562            }
563        }
564
565        Ok(dead_neurons)
566    }
567
568    /// Suggest repair action for dead neurons
569    fn suggest_neuron_repair_action(
570        &self,
571        activation_level: f32,
572        activations: &[f32],
573    ) -> NeuronRepairAction {
574        if activation_level < self.config.dead_neuron_threshold * 0.1 {
575            NeuronRepairAction::Reinitialize
576        } else if activation_level < self.config.dead_neuron_threshold * 0.5 {
577            let variance =
578                activations.iter().map(|&x| x.powi(2)).sum::<f32>() / activations.len() as f32;
579            if variance < 1e-10 {
580                NeuronRepairAction::AddNoise
581            } else {
582                NeuronRepairAction::AdjustLearningRate
583            }
584        } else {
585            NeuronRepairAction::ChangeActivationFunction
586        }
587    }
588
589    /// Perform correlation analysis
590    async fn perform_correlation_analysis(&self) -> Result<CorrelationAnalysis> {
591        // For simplification, we'll analyze correlations between input gradients
592        let gradient_vectors: Vec<&Vec<f32>> = self.input_gradients.values().collect();
593
594        if gradient_vectors.len() < 2 {
595            return Ok(CorrelationAnalysis {
596                correlation_matrix: Vec::new(),
597                significant_correlations: Vec::new(),
598                redundant_features: Vec::new(),
599                independent_features: Vec::new(),
600            });
601        }
602
603        let n = gradient_vectors.len();
604        let mut correlation_matrix = vec![vec![0.0; n]; n];
605        let mut significant_correlations = Vec::new();
606
607        // Compute correlation matrix
608        for i in 0..n {
609            for j in i..n {
610                let correlation =
611                    self.compute_correlation(gradient_vectors[i], gradient_vectors[j]);
612                correlation_matrix[i][j] = correlation;
613                correlation_matrix[j][i] = correlation;
614
615                if i != j && correlation.abs() > self.config.correlation_threshold {
616                    let correlation_type = if correlation.abs() > 0.8 {
617                        CorrelationType::Strong
618                    } else if correlation.abs() > 0.5 {
619                        CorrelationType::Moderate
620                    } else {
621                        CorrelationType::Weak
622                    };
623
624                    significant_correlations.push(CorrelationPair {
625                        feature_a: i,
626                        feature_b: j,
627                        correlation,
628                        p_value: 0.01, // Simplified p-value
629                        relationship_type: correlation_type,
630                    });
631                }
632            }
633        }
634
635        // Find redundant features (groups of highly correlated features)
636        let redundant_features = self.find_redundant_feature_groups(&correlation_matrix);
637
638        // Find independent features
639        let independent_features = self.find_independent_features(&correlation_matrix);
640
641        Ok(CorrelationAnalysis {
642            correlation_matrix,
643            significant_correlations,
644            redundant_features,
645            independent_features,
646        })
647    }
648
649    /// Compute Pearson correlation coefficient
650    fn compute_correlation(&self, x: &[f32], y: &[f32]) -> f32 {
651        if x.len() != y.len() || x.is_empty() {
652            return 0.0;
653        }
654
655        let n = x.len() as f32;
656        let mean_x = x.iter().sum::<f32>() / n;
657        let mean_y = y.iter().sum::<f32>() / n;
658
659        let numerator: f32 =
660            x.iter().zip(y.iter()).map(|(&xi, &yi)| (xi - mean_x) * (yi - mean_y)).sum();
661
662        let sum_sq_x: f32 = x.iter().map(|&xi| (xi - mean_x).powi(2)).sum();
663        let sum_sq_y: f32 = y.iter().map(|&yi| (yi - mean_y).powi(2)).sum();
664
665        let denominator = (sum_sq_x * sum_sq_y).sqrt();
666
667        if denominator > 1e-8 {
668            numerator / denominator
669        } else {
670            0.0
671        }
672    }
673
674    /// Find groups of redundant features
675    fn find_redundant_feature_groups(&self, correlation_matrix: &[Vec<f32>]) -> Vec<FeatureGroup> {
676        let mut groups = Vec::new();
677        let mut visited = HashSet::new();
678
679        for i in 0..correlation_matrix.len() {
680            if visited.contains(&i) {
681                continue;
682            }
683
684            let mut group = vec![i];
685            let mut group_correlations = Vec::new();
686
687            for j in (i + 1)..correlation_matrix.len() {
688                if correlation_matrix[i][j].abs() > 0.7 {
689                    group.push(j);
690                    group_correlations.push(correlation_matrix[i][j].abs());
691                    visited.insert(j);
692                }
693            }
694
695            if group.len() > 1 {
696                let average_correlation =
697                    group_correlations.iter().sum::<f32>() / group_correlations.len() as f32;
698                groups.push(FeatureGroup {
699                    features: group,
700                    average_correlation,
701                    group_importance: average_correlation, // Simplified importance
702                });
703            }
704
705            visited.insert(i);
706        }
707
708        groups
709    }
710
711    /// Find independent features
712    fn find_independent_features(&self, correlation_matrix: &[Vec<f32>]) -> Vec<usize> {
713        let mut independent = Vec::new();
714
715        for i in 0..correlation_matrix.len() {
716            let max_correlation = correlation_matrix[i]
717                .iter()
718                .enumerate()
719                .filter(|(j, _)| *j != i)
720                .map(|(_, &corr)| corr.abs())
721                .fold(0.0f32, |a, b| a.max(b));
722
723            if max_correlation < self.config.correlation_threshold {
724                independent.push(i);
725            }
726        }
727
728        independent
729    }
730
731    /// Generate behavior summary
732    fn generate_behavior_summary(&self, report: &mut BehaviorAnalysisReport) {
733        let total_neurons = report.activation_patterns.len();
734        let dead_neurons = report.dead_neurons.len();
735
736        report.behavior_summary.total_neurons_analyzed = total_neurons;
737        report.behavior_summary.dead_neuron_percentage = if total_neurons > 0 {
738            (dead_neurons as f32 / total_neurons as f32) * 100.0
739        } else {
740            0.0
741        };
742
743        if !report.activation_patterns.is_empty() {
744            report.behavior_summary.average_activation_sparsity = report
745                .activation_patterns
746                .iter()
747                .map(|p| p.activation_statistics.sparsity)
748                .sum::<f32>()
749                / report.activation_patterns.len() as f32;
750
751            report.behavior_summary.model_stability_score =
752                report.activation_patterns.iter().map(|p| p.stability_score).sum::<f32>()
753                    / report.activation_patterns.len() as f32;
754        }
755
756        // Simple entropy calculation for feature distribution
757        if !report.feature_importances.is_empty() {
758            let total_importance: f32 =
759                report.feature_importances.iter().map(|f| f.importance_score).sum();
760
761            if total_importance > 0.0 {
762                let entropy: f32 = report
763                    .feature_importances
764                    .iter()
765                    .map(|f| {
766                        let p = f.importance_score / total_importance;
767                        if p > 0.0 {
768                            -p * p.log2()
769                        } else {
770                            0.0
771                        }
772                    })
773                    .sum();
774                report.behavior_summary.feature_distribution_entropy = entropy;
775            }
776        }
777
778        // Overall interpretability score
779        report.behavior_summary.interpretability_score =
780            (report.behavior_summary.model_stability_score * 0.4
781                + (1.0 - report.behavior_summary.dead_neuron_percentage / 100.0) * 0.3
782                + (1.0 - report.behavior_summary.average_activation_sparsity) * 0.3)
783                .max(0.0)
784                .min(1.0);
785    }
786
787    /// Generate behavior recommendations
788    fn generate_recommendations(&self, report: &mut BehaviorAnalysisReport) {
789        // Dead neuron recommendations
790        if report.behavior_summary.dead_neuron_percentage > 20.0 {
791            report.recommendations.push(BehaviorRecommendation {
792                category: RecommendationCategory::Training,
793                priority: Priority::Critical,
794                description: format!("High percentage of dead neurons detected ({:.1}%)",
795                                   report.behavior_summary.dead_neuron_percentage),
796                implementation: "Consider reducing learning rate, changing initialization, or adding batch normalization".to_string(),
797                expected_impact: 0.8,
798            });
799        }
800
801        // Sparsity recommendations
802        if report.behavior_summary.average_activation_sparsity > 0.8 {
803            report.recommendations.push(BehaviorRecommendation {
804                category: RecommendationCategory::Architecture,
805                priority: Priority::High,
806                description: "Very sparse activations detected, model may be under-utilized".to_string(),
807                implementation: "Consider reducing model capacity or adjusting activation functions".to_string(),
808                expected_impact: 0.6,
809            });
810        }
811
812        // Stability recommendations
813        if report.behavior_summary.model_stability_score < 0.5 {
814            report.recommendations.push(BehaviorRecommendation {
815                category: RecommendationCategory::Training,
816                priority: Priority::High,
817                description: "Low model stability detected".to_string(),
818                implementation: "Consider adding regularization, reducing learning rate, or using gradient clipping".to_string(),
819                expected_impact: 0.7,
820            });
821        }
822
823        // Feature importance recommendations
824        if report.feature_importances.len() > 10 {
825            let top_features = &report.feature_importances[..5];
826            let bottom_features =
827                &report.feature_importances[report.feature_importances.len() - 5..];
828
829            let top_importance: f32 = top_features.iter().map(|f| f.importance_score).sum();
830            let bottom_importance: f32 = bottom_features.iter().map(|f| f.importance_score).sum();
831
832            if top_importance > bottom_importance * 10.0 {
833                report.recommendations.push(BehaviorRecommendation {
834                    category: RecommendationCategory::DataPreprocessing,
835                    priority: Priority::Medium,
836                    description: "Highly imbalanced feature importance detected".to_string(),
837                    implementation: "Consider feature selection or dimensionality reduction"
838                        .to_string(),
839                    expected_impact: 0.5,
840                });
841            }
842        }
843    }
844
845    /// Generate a comprehensive report
846    pub async fn generate_report(&self) -> Result<BehaviorAnalysisReport> {
847        let mut temp_analyzer = BehaviorAnalyzer {
848            config: self.config.clone(),
849            activation_history: self.activation_history.clone(),
850            input_gradients: self.input_gradients.clone(),
851            feature_attributions: self.feature_attributions.clone(),
852            analysis_cache: HashMap::new(),
853        };
854
855        temp_analyzer.analyze().await
856    }
857
858    /// Clear all recorded data
859    pub fn clear(&mut self) {
860        self.activation_history.clear();
861        self.input_gradients.clear();
862        self.feature_attributions.clear();
863        self.analysis_cache.clear();
864    }
865
866    /// Get summary of current analysis state
867    pub fn get_analysis_summary(&self) -> AnalysisSummary {
868        AnalysisSummary {
869            total_layers_tracked: self.activation_history.len(),
870            total_activation_samples: self
871                .activation_history
872                .values()
873                .map(|history| history.len())
874                .sum(),
875            total_inputs_tracked: self.input_gradients.len(),
876            analysis_coverage: if self.activation_history.is_empty() {
877                0.0
878            } else {
879                1.0 // Simplified coverage metric
880            },
881        }
882    }
883}
884
885/// Summary of analysis state
886#[derive(Debug, Clone, Serialize, Deserialize)]
887pub struct AnalysisSummary {
888    pub total_layers_tracked: usize,
889    pub total_activation_samples: usize,
890    pub total_inputs_tracked: usize,
891    pub analysis_coverage: f32,
892}