1use std::collections::HashMap;
8
9use super::types::{
10 ActivationHeatmap, AttentionVisualization, ClusteringResults, DriftInfo, HiddenStateAnalysis,
11 LayerActivationStats, LayerAnalysis, RepresentationStability, TemporalDynamics,
12 WeightDistribution,
13};
14
/// Analyzes per-layer activation statistics to surface health problems
/// (dead/saturated neurons, variance drift) and produce diagnostics,
/// heatmaps, and tuning recommendations.
#[derive(Debug)]
pub struct LayerAnalyzer {
    // Rolling activation-stats history per layer, capped at
    // `config.history_length` entries (oldest dropped first).
    layer_activations: HashMap<String, Vec<LayerActivationStats>>,
    // Thresholds and limits governing the analysis.
    config: LayerAnalysisConfig,
    // Per-layer bookkeeping: health-score history and recording counter.
    layer_states: HashMap<String, LayerState>,
}
25
/// Tunable thresholds controlling when `LayerAnalyzer` flags a layer as
/// unhealthy and how much history it retains.
#[derive(Debug, Clone)]
pub struct LayerAnalysisConfig {
    // Dead-neuron ratio above which a layer is penalized and flagged.
    pub dead_neuron_threshold: f64,
    // Saturated-neuron ratio above which a layer is penalized and flagged.
    pub saturated_neuron_threshold: f64,
    // Activation standard deviation (not variance, despite the name)
    // above which high variance is flagged.
    pub max_activation_variance: f64,
    // Minimum acceptable health score.
    // NOTE(review): not referenced by any code visible in this file —
    // confirm it is consumed elsewhere.
    pub min_health_score: f64,
    // Maximum number of activation snapshots kept per layer.
    pub history_length: usize,
}
40
impl Default for LayerAnalysisConfig {
    /// Conservative defaults: flag layers with >10% dead or saturated
    /// neurons, std-dev above 2.0, and keep the last 100 snapshots.
    fn default() -> Self {
        Self {
            dead_neuron_threshold: 0.1,
            saturated_neuron_threshold: 0.1,
            max_activation_variance: 2.0,
            min_health_score: 0.7,
            history_length: 100,
        }
    }
}
52
/// Internal per-layer bookkeeping accumulated across recordings.
#[derive(Debug, Clone)]
struct LayerState {
    // Rolling health-score history (capped at 50 entries by the recorder).
    health_scores: Vec<f64>,
    // Issue log; never written in this file (kept for future use, hence
    // the dead_code allowance).
    #[allow(dead_code)]
    detected_issues: Vec<String>,
    // Number of recordings observed for this layer so far.
    last_analysis_step: usize,
}
64
65impl Default for LayerState {
66 fn default() -> Self {
67 Self {
68 health_scores: Vec::new(),
69 detected_issues: Vec::new(),
70 last_analysis_step: 0,
71 }
72 }
73}
74
75impl LayerAnalyzer {
76 pub fn new() -> Self {
78 Self {
79 layer_activations: HashMap::new(),
80 config: LayerAnalysisConfig::default(),
81 layer_states: HashMap::new(),
82 }
83 }
84
85 pub fn with_config(config: LayerAnalysisConfig) -> Self {
87 Self {
88 layer_activations: HashMap::new(),
89 config,
90 layer_states: HashMap::new(),
91 }
92 }
93
94 pub fn record_layer_activations(&mut self, layer_name: &str, stats: LayerActivationStats) {
96 let health_score = self.calculate_layer_health_score(&stats);
98
99 let layer_stats =
100 self.layer_activations.entry(layer_name.to_string()).or_insert_with(Vec::new);
101 layer_stats.push(stats);
102
103 if layer_stats.len() > self.config.history_length {
105 layer_stats.remove(0);
106 }
107
108 let layer_state = self
110 .layer_states
111 .entry(layer_name.to_string())
112 .or_insert_with(LayerState::default);
113 layer_state.health_scores.push(health_score);
114
115 if layer_state.health_scores.len() > 50 {
116 layer_state.health_scores.remove(0);
117 }
118
119 layer_state.last_analysis_step += 1;
120 }
121
122 pub fn record_layer_stats(&mut self, stats: LayerActivationStats) {
124 let layer_name = stats.layer_name.clone();
125 self.record_layer_activations(&layer_name, stats);
126 }
127
128 pub fn get_layer_activations(&self, layer_name: &str) -> Option<&[LayerActivationStats]> {
130 self.layer_activations.get(layer_name).map(|v| v.as_slice())
131 }
132
133 pub fn perform_layer_by_layer_analysis(&self) -> Vec<LayerAnalysis> {
135 let mut analyses = Vec::new();
136
137 for (layer_name, stats_history) in &self.layer_activations {
138 if let Some(latest_stats) = stats_history.last() {
139 let analysis = self.analyze_single_layer(layer_name, latest_stats, stats_history);
140 analyses.push(analysis);
141 }
142 }
143
144 analyses.sort_by(|a, b| {
145 a.health_score.partial_cmp(&b.health_score).unwrap_or(std::cmp::Ordering::Equal)
146 });
147 analyses
148 }
149
150 pub fn analyze_single_layer(
152 &self,
153 layer_name: &str,
154 current_stats: &LayerActivationStats,
155 stats_history: &[LayerActivationStats],
156 ) -> LayerAnalysis {
157 let layer_type = self.infer_layer_type(layer_name);
158 let health_score = self.calculate_layer_health_score(current_stats);
159 let issues = self.identify_layer_issues(current_stats, stats_history);
160 let recommendations = self.generate_layer_recommendations(&issues, &layer_type);
161 let activation_summary = self.generate_activation_summary(current_stats);
162
163 LayerAnalysis {
164 layer_name: layer_name.to_string(),
165 layer_type,
166 health_score,
167 issues,
168 recommendations,
169 activation_summary,
170 }
171 }
172
173 pub fn calculate_layer_health_score(&self, stats: &LayerActivationStats) -> f64 {
175 let mut score = 1.0;
176
177 if stats.dead_neurons_ratio > self.config.dead_neuron_threshold {
179 score -= stats.dead_neurons_ratio * 0.5;
180 }
181
182 if stats.saturated_neurons_ratio > self.config.saturated_neuron_threshold {
184 score -= stats.saturated_neurons_ratio * 0.3;
185 }
186
187 let activation_range = stats.max_activation - stats.min_activation;
189 if activation_range > 10.0 {
190 score -= 0.2;
191 }
192
193 if stats.std_activation > self.config.max_activation_variance {
195 score -= 0.2;
196 }
197
198 if stats.sparsity > 0.1 && stats.sparsity < 0.8 {
200 score += 0.1;
201 }
202
203 score.max(0.0).min(1.0)
204 }
205
206 pub fn identify_layer_issues(
208 &self,
209 current_stats: &LayerActivationStats,
210 stats_history: &[LayerActivationStats],
211 ) -> Vec<String> {
212 let mut issues = Vec::new();
213
214 if current_stats.dead_neurons_ratio > self.config.dead_neuron_threshold {
216 issues.push(format!(
217 "High dead neuron ratio: {:.1}%",
218 current_stats.dead_neurons_ratio * 100.0
219 ));
220 }
221
222 if current_stats.saturated_neurons_ratio > self.config.saturated_neuron_threshold {
224 issues.push(format!(
225 "High saturated neuron ratio: {:.1}%",
226 current_stats.saturated_neurons_ratio * 100.0
227 ));
228 }
229
230 if current_stats.max_activation - current_stats.min_activation > 100.0 {
232 issues.push("Extremely wide activation range detected".to_string());
233 }
234
235 if current_stats.std_activation > self.config.max_activation_variance {
237 issues.push("High activation variance detected".to_string());
238 }
239
240 if stats_history.len() > 5 {
242 let variance_trend = self.analyze_variance_trend(stats_history);
243 if variance_trend > 0.1 {
244 issues.push("Increasing activation variance over time".to_string());
245 }
246 }
247
248 if current_stats.mean_activation.abs() < 1e-6 {
250 issues.push("Near-zero mean activation detected".to_string());
251 }
252
253 issues
254 }
255
256 pub fn generate_layer_recommendations(
258 &self,
259 issues: &[String],
260 layer_type: &str,
261 ) -> Vec<String> {
262 let mut recommendations = Vec::new();
263
264 for issue in issues {
265 if issue.contains("dead neuron") {
266 match layer_type {
267 "Linear" => recommendations
268 .push("Consider using LeakyReLU or ELU activation".to_string()),
269 "Convolutional" => recommendations.push(
270 "Consider batch normalization or different initialization".to_string(),
271 ),
272 _ => recommendations.push(
273 "Consider different activation function or initialization".to_string(),
274 ),
275 }
276 }
277
278 if issue.contains("saturated neuron") {
279 recommendations
280 .push("Consider gradient clipping or learning rate reduction".to_string());
281 recommendations.push("Consider batch normalization".to_string());
282 }
283
284 if issue.contains("activation range") {
285 recommendations.push("Consider activation clipping or normalization".to_string());
286 }
287
288 if issue.contains("variance") {
289 recommendations.push("Consider weight initialization adjustment".to_string());
290 recommendations.push("Consider adding regularization".to_string());
291 }
292
293 if issue.contains("zero activation") {
294 recommendations
295 .push("Check weight initialization and input preprocessing".to_string());
296 }
297 }
298
299 recommendations.dedup();
300 recommendations
301 }
302
303 pub fn analyze_weight_distributions(&self) -> HashMap<String, WeightDistribution> {
305 let mut distributions = HashMap::new();
306
307 for layer_name in self.layer_activations.keys() {
308 let distribution = self.analyze_layer_weight_distribution(layer_name);
309 distributions.insert(layer_name.clone(), distribution);
310 }
311
312 distributions
313 }
314
315 pub fn generate_activation_heatmaps(&self) -> HashMap<String, ActivationHeatmap> {
317 let mut heatmaps = HashMap::new();
318
319 for (layer_name, stats_history) in &self.layer_activations {
320 if let Some(latest_stats) = stats_history.last() {
321 let heatmap = self.create_activation_heatmap(layer_name, latest_stats);
322 heatmaps.insert(layer_name.clone(), heatmap);
323 }
324 }
325
326 heatmaps
327 }
328
329 pub fn generate_attention_visualizations(&self) -> HashMap<String, AttentionVisualization> {
331 let mut visualizations = HashMap::new();
332
333 for layer_name in self.layer_activations.keys() {
334 if self.infer_layer_type(layer_name) == "Attention" {
335 let visualization = self.create_attention_visualization(layer_name);
336 visualizations.insert(layer_name.clone(), visualization);
337 }
338 }
339
340 visualizations
341 }
342
343 pub fn analyze_hidden_states(&self) -> HashMap<String, HiddenStateAnalysis> {
345 let mut analyses = HashMap::new();
346
347 for layer_name in self.layer_activations.keys() {
348 let analysis = self.analyze_layer_hidden_states(layer_name);
349 analyses.insert(layer_name.clone(), analysis);
350 }
351
352 analyses
353 }
354
355 fn infer_layer_type(&self, layer_name: &str) -> String {
358 let name_lower = layer_name.to_lowercase();
359
360 if name_lower.contains("attention") || name_lower.contains("attn") {
361 "Attention".to_string()
362 } else if name_lower.contains("linear")
363 || name_lower.contains("dense")
364 || name_lower.contains("fc")
365 {
366 "Linear".to_string()
367 } else if name_lower.contains("conv") {
368 "Convolutional".to_string()
369 } else if name_lower.contains("norm")
370 || name_lower.contains("bn")
371 || name_lower.contains("ln")
372 {
373 "Normalization".to_string()
374 } else if name_lower.contains("dropout") {
375 "Dropout".to_string()
376 } else if name_lower.contains("embed") {
377 "Embedding".to_string()
378 } else {
379 "Unknown".to_string()
380 }
381 }
382
    /// One-line human-readable summary of a layer's activation statistics
    /// (ratios are rendered as percentages).
    fn generate_activation_summary(&self, stats: &LayerActivationStats) -> String {
        format!(
            "Mean: {:.3}, Std: {:.3}, Range: [{:.3}, {:.3}], Dead: {:.1}%, Saturated: {:.1}%, Sparsity: {:.1}%",
            stats.mean_activation,
            stats.std_activation,
            stats.min_activation,
            stats.max_activation,
            stats.dead_neurons_ratio * 100.0,
            stats.saturated_neurons_ratio * 100.0,
            stats.sparsity * 100.0
        )
    }
395
396 fn analyze_variance_trend(&self, stats_history: &[LayerActivationStats]) -> f64 {
397 if stats_history.len() < 2 {
398 return 0.0;
399 }
400
401 let variances: Vec<f64> = stats_history.iter().map(|s| s.std_activation.powi(2)).collect();
402 self.calculate_trend(&variances)
403 }
404
405 fn calculate_trend(&self, values: &[f64]) -> f64 {
406 if values.len() < 2 {
407 return 0.0;
408 }
409
410 let n = values.len() as f64;
411 let x_mean = (n - 1.0) / 2.0;
412 let y_mean = values.iter().sum::<f64>() / n;
413
414 let mut numerator = 0.0;
415 let mut denominator = 0.0;
416
417 for (i, &y) in values.iter().enumerate() {
418 let x = i as f64;
419 numerator += (x - x_mean) * (y - y_mean);
420 denominator += (x - x_mean).powi(2);
421 }
422
423 if denominator == 0.0 {
424 0.0
425 } else {
426 numerator / denominator
427 }
428 }
429
    /// Produces a *simulated* weight distribution for `layer_name`.
    ///
    /// NOTE(review): values are randomly generated per inferred layer
    /// type, not derived from actual model weights — placeholder until
    /// real weight access is wired in.
    fn analyze_layer_weight_distribution(&self, layer_name: &str) -> WeightDistribution {
        use scirs2_core::random::*; let mut rng = thread_rng();

        let layer_type = self.infer_layer_type(layer_name);
        // Tighter spreads for attention/conv layers, wider for linear.
        let (mean, std_dev) = match layer_type.as_str() {
            "Linear" => (rng.gen_range(-0.1..0.1), rng.gen_range(0.1..0.5)),
            "Convolutional" => (rng.gen_range(-0.05..0.05), rng.gen_range(0.05..0.3)),
            "Attention" => (rng.gen_range(-0.02..0.02), rng.gen_range(0.02..0.2)),
            _ => (rng.gen_range(-0.1..0.1), rng.gen_range(0.1..0.4)),
        };

        // Report a ±3σ range, consistent with the claimed normal shape.
        let min = mean - 3.0 * std_dev;
        let max = mean + 3.0 * std_dev;
        let sparsity = rng.gen_range(0.0..0.3);

        WeightDistribution {
            mean,
            std_dev,
            min,
            max,
            sparsity,
            distribution_shape: "Normal".to_string(),
        }
    }
456
457 fn create_activation_heatmap(
458 &self,
459 layer_name: &str,
460 stats: &LayerActivationStats,
461 ) -> ActivationHeatmap {
462 use scirs2_core::random::*; let mut rng = thread_rng();
464
465 let (height, width) = if stats.output_shape.len() >= 2 {
467 (stats.output_shape[0].min(64), stats.output_shape[1].min(64))
468 } else {
469 (32, 32)
470 };
471
472 let data: Vec<Vec<f64>> = (0..height)
473 .map(|_| {
474 (0..width)
475 .map(|_| rng.gen_range(stats.min_activation..stats.max_activation))
476 .collect()
477 })
478 .collect();
479
480 ActivationHeatmap {
481 data,
482 dimensions: (height, width),
483 value_range: (stats.min_activation, stats.max_activation),
484 interpretation: format!(
485 "Activation pattern for {} layer",
486 self.infer_layer_type(layer_name)
487 ),
488 }
489 }
490
    /// Produces a *simulated* attention visualization: a random square
    /// attention matrix over a random sequence length with synthetic
    /// token labels. The layer name is currently unused.
    fn create_attention_visualization(&self, _layer_name: &str) -> AttentionVisualization {
        use scirs2_core::random::*; let mut rng = thread_rng();

        let seq_length = rng.gen_range(10..50);
        // NOTE(review): rows are independent uniform draws and are not
        // normalized to sum to 1 — confirm whether consumers expect
        // softmax-normalized attention rows.
        let attention_weights: Vec<Vec<f64>> = (0..seq_length)
            .map(|_| (0..seq_length).map(|_| rng.gen_range(0.0..1.0)).collect())
            .collect();

        let input_tokens: Vec<String> = (0..seq_length).map(|i| format!("token_{}", i)).collect();

        // Self-attention style: output positions mirror the inputs.
        let output_tokens = input_tokens.clone();

        let patterns = vec![
            "Self-attention pattern detected".to_string(),
            "Local attention focused".to_string(),
            "Global attention pattern".to_string(),
        ];

        AttentionVisualization {
            attention_weights,
            input_tokens,
            output_tokens,
            patterns,
        }
    }
517
518 fn analyze_layer_hidden_states(&self, layer_name: &str) -> HiddenStateAnalysis {
519 use scirs2_core::random::*; let _rng = thread_rng();
521
522 let dimensionality = self.get_hidden_dimensions(layer_name);
523 let information_content = self.compute_information_content(layer_name);
524 let clustering_results = self.perform_clustering_analysis(layer_name);
525 let temporal_dynamics = self.analyze_temporal_dynamics(layer_name);
526 let representation_stability = self.assess_representation_stability(layer_name);
527
528 HiddenStateAnalysis {
529 dimensionality,
530 information_content,
531 clustering_results,
532 temporal_dynamics,
533 representation_stability,
534 }
535 }
536
537 fn get_hidden_dimensions(&self, layer_name: &str) -> usize {
538 if let Some(stats_history) = self.layer_activations.get(layer_name) {
539 if let Some(latest_stats) = stats_history.last() {
540 return latest_stats.output_shape.iter().product();
541 }
542 }
543 512 }
545
    /// Returns a *simulated* information-content score in (0, 1) for the
    /// layer, biased by inferred layer type (attention drawn highest).
    fn compute_information_content(&self, layer_name: &str) -> f64 {
        use scirs2_core::random::*; let mut rng = thread_rng();

        let layer_type = self.infer_layer_type(layer_name);
        match layer_type.as_str() {
            "Attention" => rng.gen_range(0.6..0.9),
            "Linear" => rng.gen_range(0.4..0.7),
            "Convolutional" => rng.gen_range(0.3..0.6),
            _ => rng.gen_range(0.4..0.7),
        }
    }
558
    /// Produces *simulated* clustering results over the layer's hidden
    /// space: random cluster centers (projected to at most 10 dims) and
    /// random assignments for a notional batch of 100 samples.
    fn perform_clustering_analysis(&self, layer_name: &str) -> ClusteringResults {
        use scirs2_core::random::*; let mut rng = thread_rng();

        let hidden_dims = self.get_hidden_dimensions(layer_name);
        let num_clusters = rng.gen_range(5..20);

        // Centers live in a capped projection of the hidden space to keep
        // the synthetic payload small.
        let cluster_centers: Vec<Vec<f64>> = (0..num_clusters)
            .map(|_| (0..hidden_dims.min(10)).map(|_| rng.gen_range(-1.0..1.0)).collect())
            .collect();

        let cluster_assignments: Vec<usize> =
            (0..100).map(|_| rng.gen_range(0..num_clusters)).collect();

        ClusteringResults {
            num_clusters,
            cluster_centers,
            cluster_assignments,
            silhouette_score: rng.gen_range(0.2..0.8),
            inertia: rng.gen_range(100.0..1000.0),
        }
    }
581
    /// Produces *simulated* temporal dynamics: consistency/change-rate
    /// scores, stability windows, and a 20%-probability drift event.
    /// The layer name is currently unused.
    fn analyze_temporal_dynamics(&self, _layer_name: &str) -> TemporalDynamics {
        use scirs2_core::random::*; let mut rng = thread_rng();

        let consistency = rng.gen_range(0.5..0.9);
        let change_rate = rng.gen_range(0.01..0.1);

        // Windows start on a 100-step grid with 50-150 step spans, so
        // adjacent windows may overlap.
        let num_windows = rng.gen_range(3..8);
        let stability_windows: Vec<(usize, usize)> = (0..num_windows)
            .map(|i| {
                let start = i * 100;
                let end = start + rng.gen_range(50..150);
                (start, end)
            })
            .collect();

        // Drift is simulated as a rare (p = 0.2) event; when absent, the
        // magnitude/direction/onset fields take neutral values.
        let drift_detected = rng.gen_bool(0.2);
        let drift_info = DriftInfo {
            drift_detected,
            drift_magnitude: if drift_detected { rng.gen_range(0.1..0.5) } else { 0.0 },
            drift_direction: if drift_detected {
                ["increasing", "decreasing", "oscillating"][rng.gen_range(0..3)].to_string()
            } else {
                "stable".to_string()
            },
            onset_step: if drift_detected { Some(rng.gen_range(100..1000)) } else { None },
        };

        TemporalDynamics {
            temporal_consistency: consistency,
            change_rate,
            stability_windows,
            drift_detection: drift_info,
        }
    }
617
    /// Produces *simulated* representation-stability metrics, biased by
    /// inferred layer type (normalization layers drawn most stable).
    fn assess_representation_stability(&self, layer_name: &str) -> RepresentationStability {
        use scirs2_core::random::*; let mut rng = thread_rng();

        let layer_type = self.infer_layer_type(layer_name);

        let stability_score = match layer_type.as_str() {
            "Normalization" => rng.gen_range(0.8..0.95),
            "Attention" => rng.gen_range(0.6..0.85),
            "Linear" => rng.gen_range(0.5..0.8),
            _ => rng.gen_range(0.4..0.7),
        };

        RepresentationStability {
            stability_score,
            variance_across_batches: rng.gen_range(0.01..0.1),
            consistency_measure: rng.gen_range(0.6..0.9),
            robustness_to_noise: rng.gen_range(0.3..0.8),
        }
    }
638
639 pub fn clear(&mut self) {
641 self.layer_activations.clear();
642 self.layer_states.clear();
643 }
644}
645
impl Default for LayerAnalyzer {
    /// Equivalent to [`LayerAnalyzer::new`].
    fn default() -> Self {
        Self::new()
    }
}
651
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a healthy baseline `LayerActivationStats` fixture.
    fn create_test_layer_stats(layer_name: &str) -> LayerActivationStats {
        LayerActivationStats {
            layer_name: layer_name.to_string(),
            mean_activation: 0.5,
            std_activation: 0.2,
            min_activation: 0.0,
            max_activation: 1.0,
            dead_neurons_ratio: 0.05,
            saturated_neurons_ratio: 0.03,
            sparsity: 0.3,
            output_shape: vec![128, 256],
        }
    }

    #[test]
    fn test_layer_analyzer_creation() {
        // A freshly constructed analyzer has no recorded layers.
        assert!(LayerAnalyzer::new().layer_activations.is_empty());
    }

    #[test]
    fn test_record_layer_activations() {
        let mut analyzer = LayerAnalyzer::new();
        analyzer.record_layer_activations("test_layer", create_test_layer_stats("test_layer"));

        assert_eq!(analyzer.layer_activations.len(), 1);
        assert!(analyzer.layer_activations.contains_key("test_layer"));
    }

    #[test]
    fn test_layer_health_score_calculation() {
        let score = LayerAnalyzer::new()
            .calculate_layer_health_score(&create_test_layer_stats("test_layer"));
        // A healthy fixture should score strictly inside (0, 1].
        assert!(score > 0.0);
        assert!(score <= 1.0);
    }

    #[test]
    fn test_layer_type_inference() {
        let analyzer = LayerAnalyzer::new();
        let cases = [
            ("attention_layer", "Attention"),
            ("linear_projection", "Linear"),
            ("conv2d_layer", "Convolutional"),
            ("batch_norm", "Normalization"),
        ];
        for (name, expected) in cases {
            assert_eq!(analyzer.infer_layer_type(name), expected);
        }
    }

    #[test]
    fn test_issue_identification() {
        let analyzer = LayerAnalyzer::new();
        let mut stats = create_test_layer_stats("test_layer");
        stats.dead_neurons_ratio = 0.2; // well above the default 0.1 threshold

        let issues = analyzer.identify_layer_issues(&stats, &[]);
        assert!(!issues.is_empty());
        assert!(issues[0].contains("dead neuron"));
    }

    #[test]
    fn test_layer_analysis() {
        let analyzer = LayerAnalyzer::new();
        let stats = create_test_layer_stats("attention_layer");
        let history = [stats.clone()];

        let analysis = analyzer.analyze_single_layer("attention_layer", &stats, &history);
        assert_eq!(analysis.layer_name, "attention_layer");
        assert_eq!(analysis.layer_type, "Attention");
        assert!(analysis.health_score > 0.0);
    }
}