trustformers_debug/gradient_debugger/
conflict_analysis.rs1use super::types::*;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct GradientConflictAnalysis {
13 pub total_conflicts: usize,
14 pub conflicts: Vec<GradientConflict>,
15 pub overall_conflict_level: ConflictLevel,
16 pub mitigation_strategies: Vec<ConflictMitigationStrategy>,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct GradientConflict {
22 pub layer1: String,
23 pub layer2: String,
24 pub conflict_score: f64,
25 pub conflict_type: ConflictType,
26 pub recommendations: Vec<String>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub enum ConflictType {
32 None,
33 Mild,
34 Moderate,
35 Severe,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub enum ConflictLevel {
41 Low,
42 Medium,
43 High,
44 Critical,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct ConflictMitigationStrategy {
50 pub strategy_name: String,
51 pub description: String,
52 pub target_conflicts: Vec<String>,
53 pub effectiveness: f64,
54 pub implementation_complexity: MitigationComplexity,
55 pub expected_outcome: String,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub enum MitigationComplexity {
60 Simple,
61 Moderate,
62 Complex,
63 RequiresArchitectureChange,
64}
65
/// Detects conflicting gradient behaviour between pairs of layers.
#[derive(Debug)]
pub struct GradientConflictAnalyzer {
    /// A layer pair is reported only when its conflict score is strictly above this value.
    conflict_threshold: f64,
    /// Number of most recent gradient-norm samples compared per layer.
    analysis_window: usize,
}
72
73impl Default for GradientConflictAnalyzer {
74 fn default() -> Self {
75 Self {
76 conflict_threshold: 0.5,
77 analysis_window: 10,
78 }
79 }
80}
81
82impl GradientConflictAnalyzer {
83 pub fn new(threshold: f64, window: usize) -> Self {
84 Self {
85 conflict_threshold: threshold,
86 analysis_window: window,
87 }
88 }
89
90 pub fn analyze_conflicts(
91 &self,
92 gradient_histories: &HashMap<String, GradientHistory>,
93 ) -> GradientConflictAnalysis {
94 let mut conflicts = Vec::new();
95 let mut layer_gradients: Vec<(String, Vec<f64>)> = Vec::new();
96
97 for (layer_name, history) in gradient_histories {
99 if let Some(recent_gradients) = self.get_recent_gradients(history, self.analysis_window)
100 {
101 layer_gradients.push((layer_name.clone(), recent_gradients));
102 }
103 }
104
105 for i in 0..layer_gradients.len() {
107 for j in (i + 1)..layer_gradients.len() {
108 let (layer1_name, layer1_grads) = &layer_gradients[i];
109 let (layer2_name, layer2_grads) = &layer_gradients[j];
110
111 let conflict_score = self.compute_gradient_conflict(layer1_grads, layer2_grads);
112
113 if conflict_score > self.conflict_threshold {
114 conflicts.push(GradientConflict {
115 layer1: layer1_name.clone(),
116 layer2: layer2_name.clone(),
117 conflict_score,
118 conflict_type: self.classify_conflict_type(conflict_score),
119 recommendations: self.get_conflict_recommendations(conflict_score),
120 });
121 }
122 }
123 }
124
125 let overall_conflict_level = self.compute_overall_conflict_level(&conflicts);
126 let mitigation_strategies = self.generate_conflict_mitigation_strategies(&conflicts);
127
128 GradientConflictAnalysis {
129 total_conflicts: conflicts.len(),
130 conflicts,
131 overall_conflict_level,
132 mitigation_strategies,
133 }
134 }
135
136 fn get_recent_gradients(&self, history: &GradientHistory, count: usize) -> Option<Vec<f64>> {
137 if history.gradient_norms.len() < count {
138 return None;
139 }
140
141 Some(history.gradient_norms.iter().rev().take(count).cloned().collect())
142 }
143
144 fn compute_gradient_conflict(&self, grads1: &[f64], grads2: &[f64]) -> f64 {
145 if grads1.len() != grads2.len() || grads1.is_empty() {
146 return 0.0;
147 }
148
149 let dot_product: f64 = grads1.iter().zip(grads2.iter()).map(|(a, b)| a * b).sum();
151 let norm1: f64 = grads1.iter().map(|x| x * x).sum::<f64>().sqrt();
152 let norm2: f64 = grads2.iter().map(|x| x * x).sum::<f64>().sqrt();
153
154 if norm1 == 0.0 || norm2 == 0.0 {
155 return 1.0; }
157
158 let cosine_similarity = dot_product / (norm1 * norm2);
159
160 (1.0 - cosine_similarity.abs()).max(0.0)
162 }
163
164 fn classify_conflict_type(&self, conflict_score: f64) -> ConflictType {
165 match conflict_score {
166 x if x > 0.8 => ConflictType::Severe,
167 x if x > 0.6 => ConflictType::Moderate,
168 x if x > 0.3 => ConflictType::Mild,
169 _ => ConflictType::None,
170 }
171 }
172
173 fn get_conflict_recommendations(&self, conflict_score: f64) -> Vec<String> {
174 let mut recommendations = Vec::new();
175
176 match conflict_score {
177 x if x > 0.8 => {
178 recommendations.push("Critical gradient conflict detected".to_string());
179 recommendations.push("Consider gradient clipping or normalization".to_string());
180 recommendations.push("Review learning rates for affected layers".to_string());
181 recommendations.push("Consider architectural changes".to_string());
182 },
183 x if x > 0.6 => {
184 recommendations.push("Moderate gradient conflict detected".to_string());
185 recommendations.push("Consider adjusting learning rates".to_string());
186 recommendations.push("Monitor gradient flow patterns".to_string());
187 },
188 x if x > 0.3 => {
189 recommendations.push("Mild gradient conflict detected".to_string());
190 recommendations.push("Continue monitoring conflict patterns".to_string());
191 },
192 _ => {
193 recommendations.push("No significant conflict detected".to_string());
194 },
195 }
196
197 recommendations
198 }
199
200 fn compute_overall_conflict_level(&self, conflicts: &[GradientConflict]) -> ConflictLevel {
201 if conflicts.is_empty() {
202 return ConflictLevel::Low;
203 }
204
205 let severe_conflicts = conflicts
206 .iter()
207 .filter(|c| matches!(c.conflict_type, ConflictType::Severe))
208 .count();
209 let moderate_conflicts = conflicts
210 .iter()
211 .filter(|c| matches!(c.conflict_type, ConflictType::Moderate))
212 .count();
213
214 let total_layers_with_conflicts = self.count_layers_with_conflicts(conflicts);
215
216 if severe_conflicts > 0 || total_layers_with_conflicts > 10 {
217 ConflictLevel::Critical
218 } else if moderate_conflicts > 3 || total_layers_with_conflicts > 5 {
219 ConflictLevel::High
220 } else if moderate_conflicts > 0 || total_layers_with_conflicts > 2 {
221 ConflictLevel::Medium
222 } else {
223 ConflictLevel::Low
224 }
225 }
226
227 fn count_layers_with_conflicts(&self, conflicts: &[GradientConflict]) -> usize {
228 let mut layers = std::collections::HashSet::new();
229 for conflict in conflicts {
230 layers.insert(&conflict.layer1);
231 layers.insert(&conflict.layer2);
232 }
233 layers.len()
234 }
235
236 fn generate_conflict_mitigation_strategies(
237 &self,
238 conflicts: &[GradientConflict],
239 ) -> Vec<ConflictMitigationStrategy> {
240 let mut strategies = Vec::new();
241
242 if conflicts.is_empty() {
243 return strategies;
244 }
245
246 let severe_conflicts = conflicts
248 .iter()
249 .filter(|c| matches!(c.conflict_type, ConflictType::Severe))
250 .count();
251 if severe_conflicts > 0 {
252 strategies.push(ConflictMitigationStrategy {
253 strategy_name: "Gradient Clipping".to_string(),
254 description: "Apply gradient clipping to prevent extreme gradient values"
255 .to_string(),
256 target_conflicts: conflicts
257 .iter()
258 .filter(|c| matches!(c.conflict_type, ConflictType::Severe))
259 .map(|c| format!("{}-{}", c.layer1, c.layer2))
260 .collect(),
261 effectiveness: 0.8,
262 implementation_complexity: MitigationComplexity::Simple,
263 expected_outcome: "Reduced gradient magnitude conflicts".to_string(),
264 });
265 }
266
267 if conflicts.len() > 2 {
269 strategies.push(ConflictMitigationStrategy {
270 strategy_name: "Adaptive Learning Rates".to_string(),
271 description: "Use layer-specific learning rates to balance gradient flows"
272 .to_string(),
273 target_conflicts: conflicts
274 .iter()
275 .map(|c| format!("{}-{}", c.layer1, c.layer2))
276 .collect(),
277 effectiveness: 0.7,
278 implementation_complexity: MitigationComplexity::Moderate,
279 expected_outcome: "Better gradient balance across layers".to_string(),
280 });
281 }
282
283 let high_conflict_count = conflicts
285 .iter()
286 .filter(|c| {
287 matches!(
288 c.conflict_type,
289 ConflictType::Severe | ConflictType::Moderate
290 )
291 })
292 .count();
293
294 if high_conflict_count > 1 {
295 strategies.push(ConflictMitigationStrategy {
296 strategy_name: "Gradient Normalization".to_string(),
297 description: "Normalize gradients to reduce scale conflicts".to_string(),
298 target_conflicts: conflicts
299 .iter()
300 .filter(|c| {
301 matches!(
302 c.conflict_type,
303 ConflictType::Severe | ConflictType::Moderate
304 )
305 })
306 .map(|c| format!("{}-{}", c.layer1, c.layer2))
307 .collect(),
308 effectiveness: 0.6,
309 implementation_complexity: MitigationComplexity::Simple,
310 expected_outcome: "More consistent gradient scales".to_string(),
311 });
312 }
313
314 if severe_conflicts > 3 {
316 strategies.push(ConflictMitigationStrategy {
317 strategy_name: "Architecture Modification".to_string(),
318 description: "Consider residual connections or attention mechanisms".to_string(),
319 target_conflicts: conflicts
320 .iter()
321 .filter(|c| matches!(c.conflict_type, ConflictType::Severe))
322 .map(|c| format!("{}-{}", c.layer1, c.layer2))
323 .collect(),
324 effectiveness: 0.9,
325 implementation_complexity: MitigationComplexity::RequiresArchitectureChange,
326 expected_outcome: "Fundamental improvement in gradient flow".to_string(),
327 });
328 }
329
330 strategies
331 }
332
333 pub fn generate_conflict_report(&self, analysis: &GradientConflictAnalysis) -> ConflictReport {
334 let mut layer_conflict_counts = HashMap::new();
335 #[allow(dead_code)]
336 #[allow(unused_assignments)]
337 let mut most_problematic_pairs = Vec::new();
338
339 for conflict in &analysis.conflicts {
341 *layer_conflict_counts.entry(conflict.layer1.clone()).or_insert(0) += 1;
342 *layer_conflict_counts.entry(conflict.layer2.clone()).or_insert(0) += 1;
343 }
344
345 let mut sorted_conflicts = analysis.conflicts.clone();
347 sorted_conflicts.sort_by(|a, b| {
348 b.conflict_score
349 .partial_cmp(&a.conflict_score)
350 .unwrap_or(std::cmp::Ordering::Equal)
351 });
352 most_problematic_pairs = sorted_conflicts.into_iter().take(5).collect();
353
354 let mut layer_scores: Vec<(String, usize)> = layer_conflict_counts.into_iter().collect();
356 layer_scores.sort_by_key(|item| std::cmp::Reverse(item.1));
357 let most_problematic_layers: Vec<String> =
358 layer_scores.into_iter().take(5).map(|(name, _)| name).collect();
359
360 ConflictReport {
361 total_conflicts: analysis.total_conflicts,
362 overall_level: analysis.overall_conflict_level.clone(),
363 most_problematic_layers,
364 most_problematic_pairs,
365 recommended_strategies: analysis.mitigation_strategies.clone(),
366 summary: self.generate_conflict_summary(analysis),
367 }
368 }
369
370 fn generate_conflict_summary(&self, analysis: &GradientConflictAnalysis) -> String {
371 match analysis.overall_conflict_level {
372 ConflictLevel::Critical => {
373 format!("Critical gradient conflicts detected ({} total). Immediate action required to stabilize training.", analysis.total_conflicts)
374 },
375 ConflictLevel::High => {
376 format!("High level of gradient conflicts ({} total). Consider implementing mitigation strategies.", analysis.total_conflicts)
377 },
378 ConflictLevel::Medium => {
379 format!("Moderate gradient conflicts detected ({} total). Monitor and consider optimization.", analysis.total_conflicts)
380 },
381 ConflictLevel::Low => {
382 format!(
383 "Low conflict level ({} total). Gradient flow appears stable.",
384 analysis.total_conflicts
385 )
386 },
387 }
388 }
389}
390
391#[derive(Debug, Clone, Serialize, Deserialize)]
393pub struct ConflictReport {
394 pub total_conflicts: usize,
395 pub overall_level: ConflictLevel,
396 pub most_problematic_layers: Vec<String>,
397 pub most_problematic_pairs: Vec<GradientConflict>,
398 pub recommended_strategies: Vec<ConflictMitigationStrategy>,
399 pub summary: String,
400}