Skip to main content

trustformers_debug/
hooks.rs

1//! Debugging hooks for automatic tensor and gradient tracking
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use uuid::Uuid;
7
/// Hook trigger conditions
///
/// Selects when a registered hook is allowed to run. Layer-pattern filtering
/// from the hook configuration is applied before the trigger is consulted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HookTrigger {
    /// Trigger on every forward pass
    EveryForward,
    /// Trigger on every backward pass
    EveryBackward,
    /// Trigger every N steps (when the current step is a multiple of N)
    EveryNSteps(usize),
    /// Trigger when specific conditions are met
    Conditional(HookCondition),
    /// Trigger only while the hook has never executed; later invocations are
    /// skipped. NOTE(review): the manager does not remove the hook afterwards.
    Once,
    /// Trigger on specific layers only (exact layer-name match)
    LayerSpecific(Vec<String>),
}
24
/// Conditions for conditional hooks
///
/// NOTE(review): the manager's condition evaluation currently implements
/// only `StepRange` and `Custom`; the threshold variants always evaluate to
/// `true` because the required values are not part of the hook context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HookCondition {
    /// Trigger when the loss compares to `threshold` as given by `comparison`
    LossThreshold {
        threshold: f64,
        comparison: Comparison,
    },
    /// Trigger when the gradient norm compares to `threshold` as given by `comparison`
    GradientNormThreshold {
        threshold: f64,
        comparison: Comparison,
    },
    /// Trigger when memory usage exceeds threshold
    MemoryThreshold { threshold_mb: f64 },
    /// Trigger on specific training steps (`start` and `end` are both inclusive)
    StepRange { start: usize, end: usize },
    /// Custom condition (placeholder for extensibility); currently met when
    /// the context metadata contains a key equal to this string
    Custom(String),
}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub enum Comparison {
48    Greater,
49    Less,
50    Equal,
51    GreaterEqual,
52    LessEqual,
53}
54
/// Hook action types
///
/// A hook runs its actions in order; the first failing action aborts the
/// remaining ones and the hook reports an error.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HookAction {
    /// Inspect tensor values (currently only emits a debug log entry)
    InspectTensor,
    /// Track gradient flow (currently only emits a debug log entry)
    TrackGradients,
    /// Record layer activations (currently only emits a debug log entry)
    RecordActivations,
    /// Save tensor snapshot to file; the raw tensor bytes are written to
    /// `{path}_{layer_name}_step_{step}.bin`
    SaveSnapshot { path: String },
    /// Generate alert (logged at the level matching `severity`)
    Alert {
        message: String,
        severity: AlertSeverity,
    },
    /// Execute custom callback registered under `name`; fails if no such
    /// callback has been registered
    CustomCallback { name: String },
    /// Pause training for manual inspection (currently only logs a warning)
    PauseTraining,
}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub enum AlertSeverity {
79    Info,
80    Warning,
81    Critical,
82}
83
/// Hook configuration
///
/// Describes one hook: when it fires (`trigger`, `layer_patterns`,
/// `max_executions`, `enabled`) and what it does (`actions`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookConfig {
    // Unique identifier; the manager uses it as the key for the hook,
    // its statistics and its execution counter.
    pub id: Uuid,
    // Human-readable name, copied into the hook's statistics.
    pub name: String,
    // Condition deciding whether the hook runs for a given invocation.
    pub trigger: HookTrigger,
    // Actions executed in order when the hook fires.
    pub actions: Vec<HookAction>,
    // Disabled hooks are ignored entirely (no result is reported for them).
    pub enabled: bool,
    // Optional cap on the number of executions; `None` means unlimited.
    pub max_executions: Option<usize>,
    // Regex patterns for layer names. An empty list matches every layer;
    // otherwise at least one pattern must match. Patterns that fail to
    // compile are treated as non-matching.
    pub layer_patterns: Vec<String>,
}
95
/// Hook execution context
///
/// Snapshot of the invocation a hook runs for. Callbacks receive it by
/// reference; it is `Clone` so they can keep a copy beyond the call.
#[derive(Debug, Clone)]
pub struct HookContext {
    /// Global step recorded by the manager at the time of the call
    pub step: usize,
    /// Name of the layer the tensor belongs to
    pub layer_name: String,
    /// Shape of the tensor as supplied by the caller
    pub tensor_shape: Vec<usize>,
    /// `true` for a forward pass, `false` for a backward pass
    pub is_forward: bool,
    /// Free-form key/value metadata supplied by the caller (empty if none)
    pub metadata: HashMap<String, String>,
}
105
/// Hook execution statistics
///
/// Maintained by the manager, one record per registered hook. Skipped
/// invocations are not counted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookStats {
    // Identifier of the hook these statistics belong to.
    pub hook_id: Uuid,
    // Name of the hook at registration time.
    pub hook_name: String,
    // Number of times the hook actually ran.
    pub total_executions: usize,
    // Global step of the most recent run, `None` if it never ran.
    pub last_execution_step: Option<usize>,
    // Accumulated wall-clock time spent in the hook, in milliseconds.
    pub total_execution_time_ms: f64,
    // `total_execution_time_ms / total_executions`.
    pub avg_execution_time_ms: f64,
    // Number of runs that ended in an error result.
    pub errors: usize,
}
117
/// Hook execution result
///
/// Outcome of one hook for one invocation. Derives `Clone` and equality so
/// callers can store and compare results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookResult {
    /// Every action of the hook completed without error
    Success,
    /// An action failed; carries the error message
    Error(String),
    /// The hook did not run; carries the reason
    Skipped(String),
}
125
/// Callback function type for custom hooks
///
/// The callback receives the execution context and the tensor contents as a
/// raw byte slice. Returning an error makes the owning hook report an error
/// result.
pub type HookCallback = Box<dyn Fn(&HookContext, &[u8]) -> Result<()> + Send + Sync>;
128
/// Hook manager for coordinating debugging hooks
///
/// Owns the registered hook configurations, their statistics and the named
/// custom callbacks.
pub struct HookManager {
    // Registered hooks keyed by hook id.
    hooks: HashMap<Uuid, HookConfig>,
    // Per-hook execution statistics keyed by hook id.
    hook_stats: HashMap<Uuid, HookStats>,
    // Custom callbacks keyed by the name used in `HookAction::CustomCallback`.
    callbacks: HashMap<String, HookCallback>,
    // Number of times each hook has run; used for `Once` and `max_executions`.
    execution_count: HashMap<Uuid, usize>,
    // Current training step as last set by the caller.
    global_step: usize,
    // Master switch: when false no hook is executed.
    enabled: bool,
}
138
139impl std::fmt::Debug for HookManager {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        f.debug_struct("HookManager")
142            .field("hooks", &self.hooks)
143            .field("hook_stats", &self.hook_stats)
144            .field("execution_count", &self.execution_count)
145            .field("global_step", &self.global_step)
146            .field("enabled", &self.enabled)
147            .field("callbacks", &format!("{} callbacks", self.callbacks.len()))
148            .finish()
149    }
150}
151
152impl HookManager {
153    /// Create a new hook manager
154    pub fn new() -> Self {
155        Self {
156            hooks: HashMap::new(),
157            hook_stats: HashMap::new(),
158            callbacks: HashMap::new(),
159            execution_count: HashMap::new(),
160            global_step: 0,
161            enabled: true,
162        }
163    }
164
165    /// Register a new hook
166    pub fn register_hook(&mut self, config: HookConfig) -> Result<Uuid> {
167        let hook_id = config.id;
168
169        // Initialize statistics
170        self.hook_stats.insert(
171            hook_id,
172            HookStats {
173                hook_id,
174                hook_name: config.name.clone(),
175                total_executions: 0,
176                last_execution_step: None,
177                total_execution_time_ms: 0.0,
178                avg_execution_time_ms: 0.0,
179                errors: 0,
180            },
181        );
182
183        self.execution_count.insert(hook_id, 0);
184        self.hooks.insert(hook_id, config);
185
186        tracing::debug!("Registered hook {}", hook_id);
187        Ok(hook_id)
188    }
189
    /// Register a custom callback
    ///
    /// The callback is stored under `name`, which is the key that
    /// `HookAction::CustomCallback` looks up. Registering a second callback
    /// under the same name replaces the first.
    pub fn register_callback(&mut self, name: String, callback: HookCallback) {
        self.callbacks.insert(name, callback);
    }
194
195    /// Remove a hook
196    pub fn remove_hook(&mut self, hook_id: Uuid) -> Option<HookConfig> {
197        self.hook_stats.remove(&hook_id);
198        self.execution_count.remove(&hook_id);
199        self.hooks.remove(&hook_id)
200    }
201
202    /// Enable/disable a specific hook
203    pub fn set_hook_enabled(&mut self, hook_id: Uuid, enabled: bool) -> Result<()> {
204        if let Some(hook) = self.hooks.get_mut(&hook_id) {
205            hook.enabled = enabled;
206            Ok(())
207        } else {
208            Err(anyhow::anyhow!("Hook {} not found", hook_id))
209        }
210    }
211
    /// Enable/disable all hooks
    ///
    /// Toggles the manager-wide switch; the per-hook `enabled` flags are left
    /// untouched.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }
216
    /// Update global step counter
    ///
    /// The stored step is copied into the context of every subsequent hook
    /// execution and recorded in the statistics.
    pub fn set_step(&mut self, step: usize) {
        self.global_step = step;
    }
221
222    /// Execute hooks for a tensor operation
223    pub fn execute_hooks<T>(
224        &mut self,
225        layer_name: &str,
226        tensor_data: &[T],
227        tensor_shape: &[usize],
228        is_forward: bool,
229        metadata: Option<HashMap<String, String>>,
230    ) -> Vec<(Uuid, HookResult)>
231    where
232        T: Clone + 'static,
233    {
234        if !self.enabled {
235            return Vec::new();
236        }
237
238        let context = HookContext {
239            step: self.global_step,
240            layer_name: layer_name.to_string(),
241            tensor_shape: tensor_shape.to_vec(),
242            is_forward,
243            metadata: metadata.unwrap_or_default(),
244        };
245
246        let mut results = Vec::new();
247
248        // Convert tensor data to bytes for callbacks
249        let tensor_bytes = unsafe {
250            std::slice::from_raw_parts(
251                tensor_data.as_ptr() as *const u8,
252                std::mem::size_of_val(tensor_data),
253            )
254        };
255
256        // Collect hook IDs and configs to avoid borrowing conflicts
257        let hooks_to_execute: Vec<(Uuid, HookConfig)> =
258            self.hooks.iter().map(|(id, config)| (*id, config.clone())).collect();
259
260        for (hook_id, hook_config) in hooks_to_execute {
261            if !hook_config.enabled {
262                continue;
263            }
264
265            // Check if we should execute this hook
266            if let Some(should_execute) = self.should_execute_hook(&hook_config, &context) {
267                if !should_execute {
268                    results.push((
269                        hook_id,
270                        HookResult::Skipped("Condition not met".to_string()),
271                    ));
272                    continue;
273                }
274            }
275
276            // Check execution count limits
277            let current_count = self.execution_count.get(&hook_id).copied().unwrap_or(0);
278            if let Some(max_executions) = hook_config.max_executions {
279                if current_count >= max_executions {
280                    results.push((
281                        hook_id,
282                        HookResult::Skipped("Max executions reached".to_string()),
283                    ));
284                    continue;
285                }
286            }
287
288            // Execute hook
289            let start_time = std::time::Instant::now();
290            let result = self.execute_single_hook(&hook_config, &context, tensor_bytes);
291            let execution_time = start_time.elapsed().as_millis() as f64;
292
293            // Update statistics
294            if let Some(stats) = self.hook_stats.get_mut(&hook_id) {
295                stats.total_executions += 1;
296                stats.last_execution_step = Some(self.global_step);
297                stats.total_execution_time_ms += execution_time;
298                stats.avg_execution_time_ms =
299                    stats.total_execution_time_ms / stats.total_executions as f64;
300
301                if matches!(result, HookResult::Error(_)) {
302                    stats.errors += 1;
303                }
304            }
305
306            // Update execution count
307            if let Some(count) = self.execution_count.get_mut(&hook_id) {
308                *count += 1;
309            }
310
311            results.push((hook_id, result));
312        }
313
314        results
315    }
316
    /// Get hook configuration
    ///
    /// Returns `None` if no hook with `hook_id` is registered.
    pub fn get_hook(&self, hook_id: Uuid) -> Option<&HookConfig> {
        self.hooks.get(&hook_id)
    }
321
322    /// Get all hooks
323    pub fn get_all_hooks(&self) -> Vec<&HookConfig> {
324        self.hooks.values().collect()
325    }
326
    /// Get hook statistics
    ///
    /// Returns `None` if no hook with `hook_id` is registered.
    pub fn get_hook_stats(&self, hook_id: Uuid) -> Option<&HookStats> {
        self.hook_stats.get(&hook_id)
    }
331
332    /// Get all hook statistics
333    pub fn get_all_stats(&self) -> Vec<&HookStats> {
334        self.hook_stats.values().collect()
335    }
336
337    /// Clear all hooks
338    pub fn clear_hooks(&mut self) {
339        self.hooks.clear();
340        self.hook_stats.clear();
341        self.execution_count.clear();
342        self.callbacks.clear();
343    }
344
345    /// Create a convenient tensor inspection hook
346    pub fn create_tensor_inspection_hook(&mut self, layer_patterns: Vec<String>) -> Result<Uuid> {
347        let config = HookConfig {
348            id: Uuid::new_v4(),
349            name: "Tensor Inspector".to_string(),
350            trigger: HookTrigger::EveryForward,
351            actions: vec![HookAction::InspectTensor],
352            enabled: true,
353            max_executions: None,
354            layer_patterns,
355        };
356
357        self.register_hook(config)
358    }
359
360    /// Create a gradient tracking hook
361    pub fn create_gradient_tracking_hook(&mut self, layer_patterns: Vec<String>) -> Result<Uuid> {
362        let config = HookConfig {
363            id: Uuid::new_v4(),
364            name: "Gradient Tracker".to_string(),
365            trigger: HookTrigger::EveryBackward,
366            actions: vec![HookAction::TrackGradients],
367            enabled: true,
368            max_executions: None,
369            layer_patterns,
370        };
371
372        self.register_hook(config)
373    }
374
375    /// Create a conditional alert hook
376    pub fn create_alert_hook(
377        &mut self,
378        condition: HookCondition,
379        message: String,
380        severity: AlertSeverity,
381    ) -> Result<Uuid> {
382        let config = HookConfig {
383            id: Uuid::new_v4(),
384            name: "Alert Hook".to_string(),
385            trigger: HookTrigger::Conditional(condition),
386            actions: vec![HookAction::Alert { message, severity }],
387            enabled: true,
388            max_executions: None,
389            layer_patterns: vec![".*".to_string()], // Match all layers
390        };
391
392        self.register_hook(config)
393    }
394
395    // Private helper methods
396
397    fn should_execute_hook(&self, hook: &HookConfig, context: &HookContext) -> Option<bool> {
398        // Check layer pattern matching
399        if !hook.layer_patterns.is_empty() {
400            let matches_pattern = hook.layer_patterns.iter().any(|pattern| {
401                regex::Regex::new(pattern)
402                    .map(|re| re.is_match(&context.layer_name))
403                    .unwrap_or(false)
404            });
405
406            if !matches_pattern {
407                return Some(false);
408            }
409        }
410
411        match &hook.trigger {
412            HookTrigger::EveryForward => Some(context.is_forward),
413            HookTrigger::EveryBackward => Some(!context.is_forward),
414            HookTrigger::EveryNSteps(n) => Some(context.step % n == 0),
415            HookTrigger::Conditional(condition) => {
416                Some(self.evaluate_condition(condition, context))
417            },
418            HookTrigger::Once => {
419                let count = self.execution_count.get(&hook.id).copied().unwrap_or(0);
420                Some(count == 0)
421            },
422            HookTrigger::LayerSpecific(layers) => Some(layers.contains(&context.layer_name)),
423        }
424    }
425
426    fn evaluate_condition(&self, condition: &HookCondition, context: &HookContext) -> bool {
427        match condition {
428            HookCondition::StepRange { start, end } => {
429                context.step >= *start && context.step <= *end
430            },
431            HookCondition::Custom(name) => {
432                // For custom conditions, we'd need additional context
433                // This is a placeholder implementation
434                context.metadata.contains_key(name)
435            },
436            // Other conditions would need additional context not available here
437            _ => true,
438        }
439    }
440
441    fn execute_single_hook(
442        &mut self,
443        hook: &HookConfig,
444        context: &HookContext,
445        tensor_data: &[u8],
446    ) -> HookResult {
447        for action in &hook.actions {
448            match self.execute_action(action, context, tensor_data) {
449                Ok(()) => continue,
450                Err(e) => return HookResult::Error(e.to_string()),
451            }
452        }
453        HookResult::Success
454    }
455
456    fn execute_action(
457        &mut self,
458        action: &HookAction,
459        context: &HookContext,
460        tensor_data: &[u8],
461    ) -> Result<()> {
462        match action {
463            HookAction::InspectTensor => {
464                tracing::debug!(
465                    "Inspecting tensor in layer '{}' at step {}",
466                    context.layer_name,
467                    context.step
468                );
469                // In practice, this would call the tensor inspector
470                Ok(())
471            },
472            HookAction::TrackGradients => {
473                tracing::debug!(
474                    "Tracking gradients in layer '{}' at step {}",
475                    context.layer_name,
476                    context.step
477                );
478                // In practice, this would call the gradient debugger
479                Ok(())
480            },
481            HookAction::RecordActivations => {
482                tracing::debug!(
483                    "Recording activations in layer '{}' at step {}",
484                    context.layer_name,
485                    context.step
486                );
487                // In practice, this would record activation statistics
488                Ok(())
489            },
490            HookAction::SaveSnapshot { path } => {
491                let file_path =
492                    format!("{}_{}_step_{}.bin", path, context.layer_name, context.step);
493                std::fs::write(&file_path, tensor_data)?;
494                tracing::info!("Saved tensor snapshot to {}", file_path);
495                Ok(())
496            },
497            HookAction::Alert { message, severity } => {
498                match severity {
499                    AlertSeverity::Info => tracing::info!("Hook Alert: {}", message),
500                    AlertSeverity::Warning => tracing::warn!("Hook Alert: {}", message),
501                    AlertSeverity::Critical => tracing::error!("Hook Alert: {}", message),
502                }
503                Ok(())
504            },
505            HookAction::CustomCallback { name } => {
506                if let Some(callback) = self.callbacks.get(name) {
507                    callback(context, tensor_data)?;
508                } else {
509                    return Err(anyhow::anyhow!("Callback '{}' not found", name));
510                }
511                Ok(())
512            },
513            HookAction::PauseTraining => {
514                tracing::warn!(
515                    "Training paused by hook at step {} in layer '{}'",
516                    context.step,
517                    context.layer_name
518                );
519                // In practice, this would set a flag to pause training
520                Ok(())
521            },
522        }
523    }
524}
525
526impl Default for HookManager {
527    fn default() -> Self {
528        Self::new()
529    }
530}
531
532/// Builder for creating hook configurations
533pub struct HookBuilder {
534    config: HookConfig,
535}
536
537impl HookBuilder {
538    pub fn new(name: &str) -> Self {
539        Self {
540            config: HookConfig {
541                id: Uuid::new_v4(),
542                name: name.to_string(),
543                trigger: HookTrigger::EveryForward,
544                actions: Vec::new(),
545                enabled: true,
546                max_executions: None,
547                layer_patterns: Vec::new(),
548            },
549        }
550    }
551
552    pub fn trigger(mut self, trigger: HookTrigger) -> Self {
553        self.config.trigger = trigger;
554        self
555    }
556
557    pub fn action(mut self, action: HookAction) -> Self {
558        self.config.actions.push(action);
559        self
560    }
561
562    pub fn actions(mut self, actions: Vec<HookAction>) -> Self {
563        self.config.actions = actions;
564        self
565    }
566
567    pub fn max_executions(mut self, max: usize) -> Self {
568        self.config.max_executions = Some(max);
569        self
570    }
571
572    pub fn layer_patterns(mut self, patterns: Vec<String>) -> Self {
573        self.config.layer_patterns = patterns;
574        self
575    }
576
577    pub fn enabled(mut self, enabled: bool) -> Self {
578        self.config.enabled = enabled;
579        self
580    }
581
582    pub fn build(self) -> HookConfig {
583        self.config
584    }
585}
586
/// Convenience macro for creating a tensor inspection hook configuration
///
/// `tensor_hook!(name, patterns)` expands to a `HookBuilder` chain producing
/// a `HookConfig` that runs `HookAction::InspectTensor` on every forward
/// pass of layers matching `patterns` (a `Vec<String>` of regex patterns).
/// A trailing comma is accepted.
///
/// `HookBuilder`, `HookTrigger` and `HookAction` are resolved at the call
/// site, so they must be in scope where the macro is invoked.
#[macro_export]
macro_rules! tensor_hook {
    ($name:expr, $patterns:expr $(,)?) => {
        HookBuilder::new($name)
            .trigger(HookTrigger::EveryForward)
            .action(HookAction::InspectTensor)
            .layer_patterns($patterns)
            .build()
    };
}
598
/// Convenience macro for creating a gradient tracking hook configuration
///
/// `gradient_hook!(name, patterns)` expands to a `HookBuilder` chain
/// producing a `HookConfig` that runs `HookAction::TrackGradients` on every
/// backward pass of layers matching `patterns` (a `Vec<String>` of regex
/// patterns). A trailing comma is accepted.
///
/// `HookBuilder`, `HookTrigger` and `HookAction` are resolved at the call
/// site, so they must be in scope where the macro is invoked.
#[macro_export]
macro_rules! gradient_hook {
    ($name:expr, $patterns:expr $(,)?) => {
        HookBuilder::new($name)
            .trigger(HookTrigger::EveryBackward)
            .action(HookAction::TrackGradients)
            .layer_patterns($patterns)
            .build()
    };
}
609
/// Convenience macro for creating a conditional alert hook configuration
///
/// `alert_hook!(condition, message, severity)` expands to a `HookBuilder`
/// chain producing a `HookConfig` named "Alert Hook" that emits `message`
/// (converted with `to_string()`) at the given severity whenever
/// `condition` holds. No layer patterns are set, so the hook applies to
/// every layer. A trailing comma is accepted.
///
/// `HookBuilder`, `HookTrigger` and `HookAction` are resolved at the call
/// site, so they must be in scope where the macro is invoked.
#[macro_export]
macro_rules! alert_hook {
    ($condition:expr, $message:expr, $severity:expr $(,)?) => {
        HookBuilder::new("Alert Hook")
            .trigger(HookTrigger::Conditional($condition))
            .action(HookAction::Alert {
                message: $message.to_string(),
                severity: $severity,
            })
            .build()
    };
}