trustformers_debug/gradient_debugger/
conflict_analysis.rs

//! Gradient Conflict Analysis Between Layers
//!
//! This module provides comprehensive analysis of gradient conflicts between layers,
//! including conflict detection, classification, and mitigation strategy generation.
5
6use super::types::*;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Analysis of gradient conflicts between layers
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct GradientConflictAnalysis {
13    pub total_conflicts: usize,
14    pub conflicts: Vec<GradientConflict>,
15    pub overall_conflict_level: ConflictLevel,
16    pub mitigation_strategies: Vec<ConflictMitigationStrategy>,
17}
18
19/// Individual gradient conflict between two layers
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct GradientConflict {
22    pub layer1: String,
23    pub layer2: String,
24    pub conflict_score: f64,
25    pub conflict_type: ConflictType,
26    pub recommendations: Vec<String>,
27}
28
29/// Types of gradient conflicts
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub enum ConflictType {
32    None,
33    Mild,
34    Moderate,
35    Severe,
36}
37
38/// Overall level of conflicts in the network
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub enum ConflictLevel {
41    Low,
42    Medium,
43    High,
44    Critical,
45}
46
47/// Strategy for mitigating gradient conflicts
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct ConflictMitigationStrategy {
50    pub strategy_name: String,
51    pub description: String,
52    pub target_conflicts: Vec<String>,
53    pub effectiveness: f64,
54    pub implementation_complexity: MitigationComplexity,
55    pub expected_outcome: String,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub enum MitigationComplexity {
60    Simple,
61    Moderate,
62    Complex,
63    RequiresArchitectureChange,
64}
65
/// Detects gradient conflicts between pairs of layers by comparing their recent
/// gradient-norm histories.
#[derive(Debug, Clone, PartialEq)]
pub struct GradientConflictAnalyzer {
    /// A layer pair is reported only when its conflict score is strictly greater than this.
    conflict_threshold: f64,
    /// Number of most recent gradient norms compared per layer.
    analysis_window: usize,
}
72
73impl Default for GradientConflictAnalyzer {
74    fn default() -> Self {
75        Self {
76            conflict_threshold: 0.5,
77            analysis_window: 10,
78        }
79    }
80}
81
82impl GradientConflictAnalyzer {
83    pub fn new(threshold: f64, window: usize) -> Self {
84        Self {
85            conflict_threshold: threshold,
86            analysis_window: window,
87        }
88    }
89
90    pub fn analyze_conflicts(
91        &self,
92        gradient_histories: &HashMap<String, GradientHistory>,
93    ) -> GradientConflictAnalysis {
94        let mut conflicts = Vec::new();
95        let mut layer_gradients: Vec<(String, Vec<f64>)> = Vec::new();
96
97        // Collect recent gradient norms for each layer
98        for (layer_name, history) in gradient_histories {
99            if let Some(recent_gradients) = self.get_recent_gradients(history, self.analysis_window)
100            {
101                layer_gradients.push((layer_name.clone(), recent_gradients));
102            }
103        }
104
105        // Analyze conflicts between pairs of layers
106        for i in 0..layer_gradients.len() {
107            for j in (i + 1)..layer_gradients.len() {
108                let (layer1_name, layer1_grads) = &layer_gradients[i];
109                let (layer2_name, layer2_grads) = &layer_gradients[j];
110
111                let conflict_score = self.compute_gradient_conflict(layer1_grads, layer2_grads);
112
113                if conflict_score > self.conflict_threshold {
114                    conflicts.push(GradientConflict {
115                        layer1: layer1_name.clone(),
116                        layer2: layer2_name.clone(),
117                        conflict_score,
118                        conflict_type: self.classify_conflict_type(conflict_score),
119                        recommendations: self.get_conflict_recommendations(conflict_score),
120                    });
121                }
122            }
123        }
124
125        let overall_conflict_level = self.compute_overall_conflict_level(&conflicts);
126        let mitigation_strategies = self.generate_conflict_mitigation_strategies(&conflicts);
127
128        GradientConflictAnalysis {
129            total_conflicts: conflicts.len(),
130            conflicts,
131            overall_conflict_level,
132            mitigation_strategies,
133        }
134    }
135
136    fn get_recent_gradients(&self, history: &GradientHistory, count: usize) -> Option<Vec<f64>> {
137        if history.gradient_norms.len() < count {
138            return None;
139        }
140
141        Some(history.gradient_norms.iter().rev().take(count).cloned().collect())
142    }
143
144    fn compute_gradient_conflict(&self, grads1: &[f64], grads2: &[f64]) -> f64 {
145        if grads1.len() != grads2.len() || grads1.is_empty() {
146            return 0.0;
147        }
148
149        // Compute cosine similarity between gradient patterns
150        let dot_product: f64 = grads1.iter().zip(grads2.iter()).map(|(a, b)| a * b).sum();
151        let norm1: f64 = grads1.iter().map(|x| x * x).sum::<f64>().sqrt();
152        let norm2: f64 = grads2.iter().map(|x| x * x).sum::<f64>().sqrt();
153
154        if norm1 == 0.0 || norm2 == 0.0 {
155            return 1.0; // Maximum conflict if one has zero gradients
156        }
157
158        let cosine_similarity = dot_product / (norm1 * norm2);
159
160        // Convert to conflict score (0 = no conflict, 1 = maximum conflict)
161        (1.0 - cosine_similarity.abs()).max(0.0)
162    }
163
164    fn classify_conflict_type(&self, conflict_score: f64) -> ConflictType {
165        match conflict_score {
166            x if x > 0.8 => ConflictType::Severe,
167            x if x > 0.6 => ConflictType::Moderate,
168            x if x > 0.3 => ConflictType::Mild,
169            _ => ConflictType::None,
170        }
171    }
172
173    fn get_conflict_recommendations(&self, conflict_score: f64) -> Vec<String> {
174        let mut recommendations = Vec::new();
175
176        match conflict_score {
177            x if x > 0.8 => {
178                recommendations.push("Critical gradient conflict detected".to_string());
179                recommendations.push("Consider gradient clipping or normalization".to_string());
180                recommendations.push("Review learning rates for affected layers".to_string());
181                recommendations.push("Consider architectural changes".to_string());
182            },
183            x if x > 0.6 => {
184                recommendations.push("Moderate gradient conflict detected".to_string());
185                recommendations.push("Consider adjusting learning rates".to_string());
186                recommendations.push("Monitor gradient flow patterns".to_string());
187            },
188            x if x > 0.3 => {
189                recommendations.push("Mild gradient conflict detected".to_string());
190                recommendations.push("Continue monitoring conflict patterns".to_string());
191            },
192            _ => {
193                recommendations.push("No significant conflict detected".to_string());
194            },
195        }
196
197        recommendations
198    }
199
200    fn compute_overall_conflict_level(&self, conflicts: &[GradientConflict]) -> ConflictLevel {
201        if conflicts.is_empty() {
202            return ConflictLevel::Low;
203        }
204
205        let severe_conflicts = conflicts
206            .iter()
207            .filter(|c| matches!(c.conflict_type, ConflictType::Severe))
208            .count();
209        let moderate_conflicts = conflicts
210            .iter()
211            .filter(|c| matches!(c.conflict_type, ConflictType::Moderate))
212            .count();
213
214        let total_layers_with_conflicts = self.count_layers_with_conflicts(conflicts);
215
216        if severe_conflicts > 0 || total_layers_with_conflicts > 10 {
217            ConflictLevel::Critical
218        } else if moderate_conflicts > 3 || total_layers_with_conflicts > 5 {
219            ConflictLevel::High
220        } else if moderate_conflicts > 0 || total_layers_with_conflicts > 2 {
221            ConflictLevel::Medium
222        } else {
223            ConflictLevel::Low
224        }
225    }
226
227    fn count_layers_with_conflicts(&self, conflicts: &[GradientConflict]) -> usize {
228        let mut layers = std::collections::HashSet::new();
229        for conflict in conflicts {
230            layers.insert(&conflict.layer1);
231            layers.insert(&conflict.layer2);
232        }
233        layers.len()
234    }
235
236    fn generate_conflict_mitigation_strategies(
237        &self,
238        conflicts: &[GradientConflict],
239    ) -> Vec<ConflictMitigationStrategy> {
240        let mut strategies = Vec::new();
241
242        if conflicts.is_empty() {
243            return strategies;
244        }
245
246        // Gradient clipping strategy
247        let severe_conflicts = conflicts
248            .iter()
249            .filter(|c| matches!(c.conflict_type, ConflictType::Severe))
250            .count();
251        if severe_conflicts > 0 {
252            strategies.push(ConflictMitigationStrategy {
253                strategy_name: "Gradient Clipping".to_string(),
254                description: "Apply gradient clipping to prevent extreme gradient values"
255                    .to_string(),
256                target_conflicts: conflicts
257                    .iter()
258                    .filter(|c| matches!(c.conflict_type, ConflictType::Severe))
259                    .map(|c| format!("{}-{}", c.layer1, c.layer2))
260                    .collect(),
261                effectiveness: 0.8,
262                implementation_complexity: MitigationComplexity::Simple,
263                expected_outcome: "Reduced gradient magnitude conflicts".to_string(),
264            });
265        }
266
267        // Learning rate adjustment strategy
268        if conflicts.len() > 2 {
269            strategies.push(ConflictMitigationStrategy {
270                strategy_name: "Adaptive Learning Rates".to_string(),
271                description: "Use layer-specific learning rates to balance gradient flows"
272                    .to_string(),
273                target_conflicts: conflicts
274                    .iter()
275                    .map(|c| format!("{}-{}", c.layer1, c.layer2))
276                    .collect(),
277                effectiveness: 0.7,
278                implementation_complexity: MitigationComplexity::Moderate,
279                expected_outcome: "Better gradient balance across layers".to_string(),
280            });
281        }
282
283        // Normalization strategy
284        let high_conflict_count = conflicts
285            .iter()
286            .filter(|c| {
287                matches!(
288                    c.conflict_type,
289                    ConflictType::Severe | ConflictType::Moderate
290                )
291            })
292            .count();
293
294        if high_conflict_count > 1 {
295            strategies.push(ConflictMitigationStrategy {
296                strategy_name: "Gradient Normalization".to_string(),
297                description: "Normalize gradients to reduce scale conflicts".to_string(),
298                target_conflicts: conflicts
299                    .iter()
300                    .filter(|c| {
301                        matches!(
302                            c.conflict_type,
303                            ConflictType::Severe | ConflictType::Moderate
304                        )
305                    })
306                    .map(|c| format!("{}-{}", c.layer1, c.layer2))
307                    .collect(),
308                effectiveness: 0.6,
309                implementation_complexity: MitigationComplexity::Simple,
310                expected_outcome: "More consistent gradient scales".to_string(),
311            });
312        }
313
314        // Architecture modification strategy for critical conflicts
315        if severe_conflicts > 3 {
316            strategies.push(ConflictMitigationStrategy {
317                strategy_name: "Architecture Modification".to_string(),
318                description: "Consider residual connections or attention mechanisms".to_string(),
319                target_conflicts: conflicts
320                    .iter()
321                    .filter(|c| matches!(c.conflict_type, ConflictType::Severe))
322                    .map(|c| format!("{}-{}", c.layer1, c.layer2))
323                    .collect(),
324                effectiveness: 0.9,
325                implementation_complexity: MitigationComplexity::RequiresArchitectureChange,
326                expected_outcome: "Fundamental improvement in gradient flow".to_string(),
327            });
328        }
329
330        strategies
331    }
332
333    pub fn generate_conflict_report(&self, analysis: &GradientConflictAnalysis) -> ConflictReport {
334        let mut layer_conflict_counts = HashMap::new();
335        #[allow(dead_code)]
336        #[allow(unused_assignments)]
337        let mut most_problematic_pairs = Vec::new();
338
339        // Count conflicts per layer
340        for conflict in &analysis.conflicts {
341            *layer_conflict_counts.entry(conflict.layer1.clone()).or_insert(0) += 1;
342            *layer_conflict_counts.entry(conflict.layer2.clone()).or_insert(0) += 1;
343        }
344
345        // Find most problematic layer pairs
346        let mut sorted_conflicts = analysis.conflicts.clone();
347        sorted_conflicts.sort_by(|a, b| b.conflict_score.partial_cmp(&a.conflict_score).unwrap());
348        most_problematic_pairs = sorted_conflicts.into_iter().take(5).collect();
349
350        // Find most problematic layers
351        let mut layer_scores: Vec<(String, usize)> = layer_conflict_counts.into_iter().collect();
352        layer_scores.sort_by(|a, b| b.1.cmp(&a.1));
353        let most_problematic_layers: Vec<String> =
354            layer_scores.into_iter().take(5).map(|(name, _)| name).collect();
355
356        ConflictReport {
357            total_conflicts: analysis.total_conflicts,
358            overall_level: analysis.overall_conflict_level.clone(),
359            most_problematic_layers,
360            most_problematic_pairs,
361            recommended_strategies: analysis.mitigation_strategies.clone(),
362            summary: self.generate_conflict_summary(analysis),
363        }
364    }
365
366    fn generate_conflict_summary(&self, analysis: &GradientConflictAnalysis) -> String {
367        match analysis.overall_conflict_level {
368            ConflictLevel::Critical => {
369                format!("Critical gradient conflicts detected ({} total). Immediate action required to stabilize training.", analysis.total_conflicts)
370            },
371            ConflictLevel::High => {
372                format!("High level of gradient conflicts ({} total). Consider implementing mitigation strategies.", analysis.total_conflicts)
373            },
374            ConflictLevel::Medium => {
375                format!("Moderate gradient conflicts detected ({} total). Monitor and consider optimization.", analysis.total_conflicts)
376            },
377            ConflictLevel::Low => {
378                format!(
379                    "Low conflict level ({} total). Gradient flow appears stable.",
380                    analysis.total_conflicts
381                )
382            },
383        }
384    }
385}
386
387/// Comprehensive conflict analysis report
388#[derive(Debug, Clone, Serialize, Deserialize)]
389pub struct ConflictReport {
390    pub total_conflicts: usize,
391    pub overall_level: ConflictLevel,
392    pub most_problematic_layers: Vec<String>,
393    pub most_problematic_pairs: Vec<GradientConflict>,
394    pub recommended_strategies: Vec<ConflictMitigationStrategy>,
395    pub summary: String,
396}