Skip to main content

trustformers_debug/
hooks.rs

1//! Debugging hooks for automatic tensor and gradient tracking
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use uuid::Uuid;
7
8/// Hook trigger conditions
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub enum HookTrigger {
11    /// Trigger on every forward pass
12    EveryForward,
13    /// Trigger on every backward pass
14    EveryBackward,
15    /// Trigger every N steps
16    EveryNSteps(usize),
17    /// Trigger when specific conditions are met
18    Conditional(HookCondition),
19    /// Trigger once and then remove
20    Once,
21    /// Trigger on specific layers only
22    LayerSpecific(Vec<String>),
23}
24
25/// Conditions for conditional hooks
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub enum HookCondition {
28    /// Trigger when loss exceeds threshold
29    LossThreshold {
30        threshold: f64,
31        comparison: Comparison,
32    },
33    /// Trigger when gradient norm exceeds threshold
34    GradientNormThreshold {
35        threshold: f64,
36        comparison: Comparison,
37    },
38    /// Trigger when memory usage exceeds threshold
39    MemoryThreshold { threshold_mb: f64 },
40    /// Trigger on specific training steps
41    StepRange { start: usize, end: usize },
42    /// Custom condition (placeholder for extensibility)
43    Custom(String),
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub enum Comparison {
48    Greater,
49    Less,
50    Equal,
51    GreaterEqual,
52    LessEqual,
53}
54
55/// Hook action types
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub enum HookAction {
58    /// Inspect tensor values
59    InspectTensor,
60    /// Track gradient flow
61    TrackGradients,
62    /// Record layer activations
63    RecordActivations,
64    /// Save tensor snapshot to file
65    SaveSnapshot { path: String },
66    /// Generate alert
67    Alert {
68        message: String,
69        severity: AlertSeverity,
70    },
71    /// Execute custom callback
72    CustomCallback { name: String },
73    /// Pause training for manual inspection
74    PauseTraining,
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub enum AlertSeverity {
79    Info,
80    Warning,
81    Critical,
82}
83
84/// Hook configuration
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct HookConfig {
87    pub id: Uuid,
88    pub name: String,
89    pub trigger: HookTrigger,
90    pub actions: Vec<HookAction>,
91    pub enabled: bool,
92    pub max_executions: Option<usize>,
93    pub layer_patterns: Vec<String>, // Regex patterns for layer names
94}
95
/// Information about the tensor operation a hook is being executed for.
///
/// The context is `Clone` so callbacks and tests can keep a copy beyond the call that
/// handed them a reference.
#[derive(Debug, Clone)]
pub struct HookContext {
    /// Training step at the time of the invocation.
    pub step: usize,
    /// Name of the layer that produced the tensor.
    pub layer_name: String,
    /// Shape of the tensor, as supplied by the caller.
    pub tensor_shape: Vec<usize>,
    /// `true` for a forward pass, `false` for a backward pass.
    pub is_forward: bool,
    /// Free-form key/value data supplied by the caller.
    pub metadata: HashMap<String, String>,
}
105
106/// Hook execution statistics
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct HookStats {
109    pub hook_id: Uuid,
110    pub hook_name: String,
111    pub total_executions: usize,
112    pub last_execution_step: Option<usize>,
113    pub total_execution_time_ms: f64,
114    pub avg_execution_time_ms: f64,
115    pub errors: usize,
116}
117
/// Outcome of considering one hook during a hook-execution pass.
///
/// Results are `Clone` and comparable so callers and tests can store them and use
/// `assert_eq!` on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookResult {
    /// Every action of the hook completed.
    Success,
    /// An action failed; carries the error message.
    Error(String),
    /// The hook did not run; carries the reason.
    Skipped(String),
}
125
126/// Callback function type for custom hooks
127pub type HookCallback = Box<dyn Fn(&HookContext, &[u8]) -> Result<()> + Send + Sync>;
128
129/// Hook manager for coordinating debugging hooks
130pub struct HookManager {
131    hooks: HashMap<Uuid, HookConfig>,
132    hook_stats: HashMap<Uuid, HookStats>,
133    callbacks: HashMap<String, HookCallback>,
134    execution_count: HashMap<Uuid, usize>,
135    global_step: usize,
136    enabled: bool,
137}
138
139impl std::fmt::Debug for HookManager {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        f.debug_struct("HookManager")
142            .field("hooks", &self.hooks)
143            .field("hook_stats", &self.hook_stats)
144            .field("execution_count", &self.execution_count)
145            .field("global_step", &self.global_step)
146            .field("enabled", &self.enabled)
147            .field("callbacks", &format!("{} callbacks", self.callbacks.len()))
148            .finish()
149    }
150}
151
152impl HookManager {
153    /// Create a new hook manager
154    pub fn new() -> Self {
155        Self {
156            hooks: HashMap::new(),
157            hook_stats: HashMap::new(),
158            callbacks: HashMap::new(),
159            execution_count: HashMap::new(),
160            global_step: 0,
161            enabled: true,
162        }
163    }
164
165    /// Register a new hook
166    pub fn register_hook(&mut self, config: HookConfig) -> Result<Uuid> {
167        let hook_id = config.id;
168
169        // Initialize statistics
170        self.hook_stats.insert(
171            hook_id,
172            HookStats {
173                hook_id,
174                hook_name: config.name.clone(),
175                total_executions: 0,
176                last_execution_step: None,
177                total_execution_time_ms: 0.0,
178                avg_execution_time_ms: 0.0,
179                errors: 0,
180            },
181        );
182
183        self.execution_count.insert(hook_id, 0);
184        self.hooks.insert(hook_id, config);
185
186        tracing::debug!("Registered hook {}", hook_id);
187        Ok(hook_id)
188    }
189
190    /// Register a custom callback
191    pub fn register_callback(&mut self, name: String, callback: HookCallback) {
192        self.callbacks.insert(name, callback);
193    }
194
195    /// Remove a hook
196    pub fn remove_hook(&mut self, hook_id: Uuid) -> Option<HookConfig> {
197        self.hook_stats.remove(&hook_id);
198        self.execution_count.remove(&hook_id);
199        self.hooks.remove(&hook_id)
200    }
201
202    /// Enable/disable a specific hook
203    pub fn set_hook_enabled(&mut self, hook_id: Uuid, enabled: bool) -> Result<()> {
204        if let Some(hook) = self.hooks.get_mut(&hook_id) {
205            hook.enabled = enabled;
206            Ok(())
207        } else {
208            Err(anyhow::anyhow!("Hook {} not found", hook_id))
209        }
210    }
211
212    /// Enable/disable all hooks
213    pub fn set_enabled(&mut self, enabled: bool) {
214        self.enabled = enabled;
215    }
216
217    /// Update global step counter
218    pub fn set_step(&mut self, step: usize) {
219        self.global_step = step;
220    }
221
222    /// Execute hooks for a tensor operation
223    pub fn execute_hooks<T>(
224        &mut self,
225        layer_name: &str,
226        tensor_data: &[T],
227        tensor_shape: &[usize],
228        is_forward: bool,
229        metadata: Option<HashMap<String, String>>,
230    ) -> Vec<(Uuid, HookResult)>
231    where
232        T: Clone + 'static,
233    {
234        if !self.enabled {
235            return Vec::new();
236        }
237
238        let context = HookContext {
239            step: self.global_step,
240            layer_name: layer_name.to_string(),
241            tensor_shape: tensor_shape.to_vec(),
242            is_forward,
243            metadata: metadata.unwrap_or_default(),
244        };
245
246        let mut results = Vec::new();
247
248        // Convert tensor data to bytes for callbacks
249        let tensor_bytes = unsafe {
250            std::slice::from_raw_parts(
251                tensor_data.as_ptr() as *const u8,
252                std::mem::size_of_val(tensor_data),
253            )
254        };
255
256        // Collect hook IDs and configs to avoid borrowing conflicts
257        let hooks_to_execute: Vec<(Uuid, HookConfig)> =
258            self.hooks.iter().map(|(id, config)| (*id, config.clone())).collect();
259
260        for (hook_id, hook_config) in hooks_to_execute {
261            if !hook_config.enabled {
262                continue;
263            }
264
265            // Check if we should execute this hook
266            if let Some(should_execute) = self.should_execute_hook(&hook_config, &context) {
267                if !should_execute {
268                    results.push((
269                        hook_id,
270                        HookResult::Skipped("Condition not met".to_string()),
271                    ));
272                    continue;
273                }
274            }
275
276            // Check execution count limits
277            let current_count = self.execution_count.get(&hook_id).copied().unwrap_or(0);
278            if let Some(max_executions) = hook_config.max_executions {
279                if current_count >= max_executions {
280                    results.push((
281                        hook_id,
282                        HookResult::Skipped("Max executions reached".to_string()),
283                    ));
284                    continue;
285                }
286            }
287
288            // Execute hook
289            let start_time = std::time::Instant::now();
290            let result = self.execute_single_hook(&hook_config, &context, tensor_bytes);
291            let execution_time = start_time.elapsed().as_millis() as f64;
292
293            // Update statistics
294            if let Some(stats) = self.hook_stats.get_mut(&hook_id) {
295                stats.total_executions += 1;
296                stats.last_execution_step = Some(self.global_step);
297                stats.total_execution_time_ms += execution_time;
298                stats.avg_execution_time_ms =
299                    stats.total_execution_time_ms / stats.total_executions as f64;
300
301                if matches!(result, HookResult::Error(_)) {
302                    stats.errors += 1;
303                }
304            }
305
306            // Update execution count
307            if let Some(count) = self.execution_count.get_mut(&hook_id) {
308                *count += 1;
309            }
310
311            results.push((hook_id, result));
312        }
313
314        results
315    }
316
317    /// Get hook configuration
318    pub fn get_hook(&self, hook_id: Uuid) -> Option<&HookConfig> {
319        self.hooks.get(&hook_id)
320    }
321
322    /// Get all hooks
323    pub fn get_all_hooks(&self) -> Vec<&HookConfig> {
324        self.hooks.values().collect()
325    }
326
327    /// Get hook statistics
328    pub fn get_hook_stats(&self, hook_id: Uuid) -> Option<&HookStats> {
329        self.hook_stats.get(&hook_id)
330    }
331
332    /// Get all hook statistics
333    pub fn get_all_stats(&self) -> Vec<&HookStats> {
334        self.hook_stats.values().collect()
335    }
336
337    /// Clear all hooks
338    pub fn clear_hooks(&mut self) {
339        self.hooks.clear();
340        self.hook_stats.clear();
341        self.execution_count.clear();
342        self.callbacks.clear();
343    }
344
345    /// Create a convenient tensor inspection hook
346    pub fn create_tensor_inspection_hook(&mut self, layer_patterns: Vec<String>) -> Result<Uuid> {
347        let config = HookConfig {
348            id: Uuid::new_v4(),
349            name: "Tensor Inspector".to_string(),
350            trigger: HookTrigger::EveryForward,
351            actions: vec![HookAction::InspectTensor],
352            enabled: true,
353            max_executions: None,
354            layer_patterns,
355        };
356
357        self.register_hook(config)
358    }
359
360    /// Create a gradient tracking hook
361    pub fn create_gradient_tracking_hook(&mut self, layer_patterns: Vec<String>) -> Result<Uuid> {
362        let config = HookConfig {
363            id: Uuid::new_v4(),
364            name: "Gradient Tracker".to_string(),
365            trigger: HookTrigger::EveryBackward,
366            actions: vec![HookAction::TrackGradients],
367            enabled: true,
368            max_executions: None,
369            layer_patterns,
370        };
371
372        self.register_hook(config)
373    }
374
375    /// Create a conditional alert hook
376    pub fn create_alert_hook(
377        &mut self,
378        condition: HookCondition,
379        message: String,
380        severity: AlertSeverity,
381    ) -> Result<Uuid> {
382        let config = HookConfig {
383            id: Uuid::new_v4(),
384            name: "Alert Hook".to_string(),
385            trigger: HookTrigger::Conditional(condition),
386            actions: vec![HookAction::Alert { message, severity }],
387            enabled: true,
388            max_executions: None,
389            layer_patterns: vec![".*".to_string()], // Match all layers
390        };
391
392        self.register_hook(config)
393    }
394
395    // Private helper methods
396
397    fn should_execute_hook(&self, hook: &HookConfig, context: &HookContext) -> Option<bool> {
398        // Check layer pattern matching
399        if !hook.layer_patterns.is_empty() {
400            let matches_pattern = hook.layer_patterns.iter().any(|pattern| {
401                regex::Regex::new(pattern)
402                    .map(|re| re.is_match(&context.layer_name))
403                    .unwrap_or(false)
404            });
405
406            if !matches_pattern {
407                return Some(false);
408            }
409        }
410
411        match &hook.trigger {
412            HookTrigger::EveryForward => Some(context.is_forward),
413            HookTrigger::EveryBackward => Some(!context.is_forward),
414            HookTrigger::EveryNSteps(n) => Some(context.step.is_multiple_of(*n)),
415            HookTrigger::Conditional(condition) => {
416                Some(self.evaluate_condition(condition, context))
417            },
418            HookTrigger::Once => {
419                let count = self.execution_count.get(&hook.id).copied().unwrap_or(0);
420                Some(count == 0)
421            },
422            HookTrigger::LayerSpecific(layers) => Some(layers.contains(&context.layer_name)),
423        }
424    }
425
426    fn evaluate_condition(&self, condition: &HookCondition, context: &HookContext) -> bool {
427        match condition {
428            HookCondition::StepRange { start, end } => {
429                context.step >= *start && context.step <= *end
430            },
431            HookCondition::Custom(name) => {
432                // For custom conditions, we'd need additional context
433                // This is a placeholder implementation
434                context.metadata.contains_key(name)
435            },
436            // Other conditions would need additional context not available here
437            _ => true,
438        }
439    }
440
441    fn execute_single_hook(
442        &mut self,
443        hook: &HookConfig,
444        context: &HookContext,
445        tensor_data: &[u8],
446    ) -> HookResult {
447        for action in &hook.actions {
448            match self.execute_action(action, context, tensor_data) {
449                Ok(()) => continue,
450                Err(e) => return HookResult::Error(e.to_string()),
451            }
452        }
453        HookResult::Success
454    }
455
456    fn execute_action(
457        &mut self,
458        action: &HookAction,
459        context: &HookContext,
460        tensor_data: &[u8],
461    ) -> Result<()> {
462        match action {
463            HookAction::InspectTensor => {
464                tracing::debug!(
465                    "Inspecting tensor in layer '{}' at step {}",
466                    context.layer_name,
467                    context.step
468                );
469                // In practice, this would call the tensor inspector
470                Ok(())
471            },
472            HookAction::TrackGradients => {
473                tracing::debug!(
474                    "Tracking gradients in layer '{}' at step {}",
475                    context.layer_name,
476                    context.step
477                );
478                // In practice, this would call the gradient debugger
479                Ok(())
480            },
481            HookAction::RecordActivations => {
482                tracing::debug!(
483                    "Recording activations in layer '{}' at step {}",
484                    context.layer_name,
485                    context.step
486                );
487                // In practice, this would record activation statistics
488                Ok(())
489            },
490            HookAction::SaveSnapshot { path } => {
491                let file_path =
492                    format!("{}_{}_step_{}.bin", path, context.layer_name, context.step);
493                std::fs::write(&file_path, tensor_data)?;
494                tracing::info!("Saved tensor snapshot to {}", file_path);
495                Ok(())
496            },
497            HookAction::Alert { message, severity } => {
498                match severity {
499                    AlertSeverity::Info => tracing::info!("Hook Alert: {}", message),
500                    AlertSeverity::Warning => tracing::warn!("Hook Alert: {}", message),
501                    AlertSeverity::Critical => tracing::error!("Hook Alert: {}", message),
502                }
503                Ok(())
504            },
505            HookAction::CustomCallback { name } => {
506                if let Some(callback) = self.callbacks.get(name) {
507                    callback(context, tensor_data)?;
508                } else {
509                    return Err(anyhow::anyhow!("Callback '{}' not found", name));
510                }
511                Ok(())
512            },
513            HookAction::PauseTraining => {
514                tracing::warn!(
515                    "Training paused by hook at step {} in layer '{}'",
516                    context.step,
517                    context.layer_name
518                );
519                // In practice, this would set a flag to pause training
520                Ok(())
521            },
522        }
523    }
524}
525
526impl Default for HookManager {
527    fn default() -> Self {
528        Self::new()
529    }
530}
531
532/// Builder for creating hook configurations
533pub struct HookBuilder {
534    config: HookConfig,
535}
536
537impl HookBuilder {
538    pub fn new(name: &str) -> Self {
539        Self {
540            config: HookConfig {
541                id: Uuid::new_v4(),
542                name: name.to_string(),
543                trigger: HookTrigger::EveryForward,
544                actions: Vec::new(),
545                enabled: true,
546                max_executions: None,
547                layer_patterns: Vec::new(),
548            },
549        }
550    }
551
552    pub fn trigger(mut self, trigger: HookTrigger) -> Self {
553        self.config.trigger = trigger;
554        self
555    }
556
557    pub fn action(mut self, action: HookAction) -> Self {
558        self.config.actions.push(action);
559        self
560    }
561
562    pub fn actions(mut self, actions: Vec<HookAction>) -> Self {
563        self.config.actions = actions;
564        self
565    }
566
567    pub fn max_executions(mut self, max: usize) -> Self {
568        self.config.max_executions = Some(max);
569        self
570    }
571
572    pub fn layer_patterns(mut self, patterns: Vec<String>) -> Self {
573        self.config.layer_patterns = patterns;
574        self
575    }
576
577    pub fn enabled(mut self, enabled: bool) -> Self {
578        self.config.enabled = enabled;
579        self
580    }
581
582    pub fn build(self) -> HookConfig {
583        self.config
584    }
585}
586
/// Convenience macro: build a `HookConfig` that inspects tensors on every forward pass.
///
/// Takes the hook name and the layer-name regex patterns. `HookBuilder`, `HookTrigger`
/// and `HookAction` are referenced by their bare names, so they must be in scope where
/// the macro is invoked.
#[macro_export]
macro_rules! tensor_hook {
    ($hook_name:expr, $layer_patterns:expr) => {
        HookBuilder::new($hook_name)
            .layer_patterns($layer_patterns)
            .action(HookAction::InspectTensor)
            .trigger(HookTrigger::EveryForward)
            .build()
    };
}
598
/// Convenience macro: build a `HookConfig` that tracks gradients on every backward pass.
///
/// Takes the hook name and the layer-name regex patterns. `HookBuilder`, `HookTrigger`
/// and `HookAction` are referenced by their bare names, so they must be in scope where
/// the macro is invoked.
#[macro_export]
macro_rules! gradient_hook {
    ($hook_name:expr, $layer_patterns:expr) => {
        HookBuilder::new($hook_name)
            .layer_patterns($layer_patterns)
            .action(HookAction::TrackGradients)
            .trigger(HookTrigger::EveryBackward)
            .build()
    };
}
609
/// Convenience macro: build a `HookConfig` named "Alert Hook" that logs an alert when a
/// condition is satisfied.
///
/// Takes the condition, the alert message (anything with `to_string`) and the severity.
/// No layer patterns are set, so the hook is not restricted by layer name.
/// `HookBuilder`, `HookTrigger` and `HookAction` are referenced by their bare names, so
/// they must be in scope where the macro is invoked.
#[macro_export]
macro_rules! alert_hook {
    ($cond:expr, $text:expr, $level:expr) => {
        HookBuilder::new("Alert Hook")
            .trigger(HookTrigger::Conditional($cond))
            .action(HookAction::Alert {
                message: $text.to_string(),
                severity: $level,
            })
            .build()
    };
}
622
623// ─────────────────────────────────────────────────────────────────────────────
624// Tests
625// ─────────────────────────────────────────────────────────────────────────────
626
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an enabled hook with a single `InspectTensor` action, no layer patterns and
    /// no execution limit.
    fn make_hook_config(name: &str, trigger: HookTrigger) -> HookConfig {
        HookConfig {
            id: Uuid::new_v4(),
            name: name.to_string(),
            trigger,
            actions: vec![HookAction::InspectTensor],
            enabled: true,
            max_executions: None,
            layer_patterns: vec![],
        }
    }

    // ── HookManager construction ────────────────────────────────────────────

    #[test]
    fn test_hook_manager_new_defaults() {
        let mgr = HookManager::new();
        assert!(mgr.enabled);
        assert_eq!(mgr.global_step, 0);
        assert!(mgr.get_all_hooks().is_empty());
        assert!(mgr.get_all_stats().is_empty());
    }

    #[test]
    fn test_hook_manager_default_equals_new() {
        let mgr = HookManager::default();
        assert!(mgr.enabled);
        assert_eq!(mgr.global_step, 0);
        assert!(mgr.get_all_hooks().is_empty());
        assert!(mgr.get_all_stats().is_empty());
    }

    // ── register_hook ──────────────────────────────────────────────────────

    #[test]
    fn test_register_hook_returns_uuid() {
        let mut mgr = HookManager::new();
        let config = make_hook_config("test", HookTrigger::EveryForward);
        let id = config.id;
        let returned = mgr.register_hook(config).expect("register should succeed");
        assert_eq!(returned, id);
    }

    #[test]
    fn test_register_multiple_hooks() {
        let mut mgr = HookManager::new();
        for i in 0..5 {
            let cfg = make_hook_config(&format!("h{}", i), HookTrigger::EveryForward);
            mgr.register_hook(cfg).expect("register should succeed");
        }
        assert_eq!(mgr.get_all_hooks().len(), 5);
    }

    #[test]
    fn test_hook_stats_initialized_on_register() {
        let mut mgr = HookManager::new();
        let cfg = make_hook_config("h0", HookTrigger::EveryForward);
        let id = mgr.register_hook(cfg).expect("register should succeed");
        let stats = mgr.get_hook_stats(id).expect("stats should exist");
        assert_eq!(stats.total_executions, 0);
        assert_eq!(stats.errors, 0);
        assert_eq!(stats.last_execution_step, None);
    }

    // ── remove_hook ────────────────────────────────────────────────────────

    #[test]
    fn test_remove_hook_returns_config() {
        let mut mgr = HookManager::new();
        let cfg = make_hook_config("remove_me", HookTrigger::EveryBackward);
        let id = mgr.register_hook(cfg).expect("register");
        let removed = mgr.remove_hook(id);
        assert!(removed.is_some());
        assert_eq!(removed.expect("should be some").name, "remove_me");
        // Bookkeeping for the hook is gone as well.
        assert!(mgr.get_hook(id).is_none());
        assert!(mgr.get_hook_stats(id).is_none());
    }

    #[test]
    fn test_remove_nonexistent_hook_returns_none() {
        let mut mgr = HookManager::new();
        let id = Uuid::new_v4();
        assert!(mgr.remove_hook(id).is_none());
    }

    // ── set_hook_enabled ───────────────────────────────────────────────────

    #[test]
    fn test_set_hook_enabled_ok() {
        let mut mgr = HookManager::new();
        let cfg = make_hook_config("h", HookTrigger::EveryForward);
        let id = mgr.register_hook(cfg).expect("register");
        mgr.set_hook_enabled(id, false).expect("should succeed");
        let hook = mgr.get_hook(id).expect("hook should exist");
        assert!(!hook.enabled);
        mgr.set_hook_enabled(id, true).expect("re-enable");
        let hook = mgr.get_hook(id).expect("hook should exist");
        assert!(hook.enabled);
    }

    #[test]
    fn test_set_hook_enabled_nonexistent_errors() {
        let mut mgr = HookManager::new();
        let result = mgr.set_hook_enabled(Uuid::new_v4(), true);
        assert!(result.is_err());
    }

    // ── set_enabled (global) ───────────────────────────────────────────────

    #[test]
    fn test_global_disable_stops_execution() {
        let mut mgr = HookManager::new();
        mgr.set_enabled(false);
        mgr.register_hook(make_hook_config("h", HookTrigger::EveryForward))
            .expect("register");
        let results = mgr.execute_hooks("layer", &[1u8, 2u8], &[2], true, None);
        assert!(
            results.is_empty(),
            "globally disabled manager should execute nothing"
        );
    }

    // ── set_step ───────────────────────────────────────────────────────────

    #[test]
    fn test_set_step_updates_counter() {
        let mut mgr = HookManager::new();
        mgr.set_step(42);
        assert_eq!(mgr.global_step, 42);
    }

    // ── execute_hooks ──────────────────────────────────────────────────────

    #[test]
    fn test_execute_hooks_disabled_hook_skipped() {
        let mut mgr = HookManager::new();
        let mut cfg = make_hook_config("h", HookTrigger::EveryForward);
        cfg.enabled = false;
        mgr.register_hook(cfg).expect("register");
        let results = mgr.execute_hooks("layer", &[0u8], &[1], true, None);
        // Disabled hook → no results (the impl skips it without adding an entry)
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_execute_hooks_every_forward_fires_on_forward() {
        let mut mgr = HookManager::new();
        mgr.register_hook(make_hook_config("h", HookTrigger::EveryForward))
            .expect("register");
        let results = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert_eq!(results.len(), 1);
        assert!(matches!(results[0].1, HookResult::Success));
    }

    #[test]
    fn test_execute_hooks_every_forward_skipped_on_backward() {
        let mut mgr = HookManager::new();
        // No layer_patterns → pattern check skipped, trigger decides.
        let cfg = make_hook_config("h", HookTrigger::EveryForward);
        mgr.register_hook(cfg).expect("register");
        let results = mgr.execute_hooks("layer", &[1u8], &[1], false, None);
        // is_forward=false → the trigger does not match → Skipped
        assert_eq!(results.len(), 1);
        let (_, ref outcome) = results[0];
        assert!(matches!(outcome, HookResult::Skipped(_)));
    }

    #[test]
    fn test_execute_hooks_max_executions_respected() {
        let mut mgr = HookManager::new();
        let mut cfg = make_hook_config("once", HookTrigger::EveryForward);
        cfg.max_executions = Some(1);
        mgr.register_hook(cfg).expect("register");

        // First execution should succeed
        let r1 = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert_eq!(r1.len(), 1);
        assert!(matches!(r1[0].1, HookResult::Success));

        // Second execution should be Skipped
        let r2 = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert_eq!(r2.len(), 1);
        assert!(matches!(r2[0].1, HookResult::Skipped(_)));
    }

    #[test]
    fn test_execute_hooks_updates_stats() {
        let mut mgr = HookManager::new();
        let id = mgr
            .register_hook(make_hook_config("h", HookTrigger::EveryForward))
            .expect("register");
        mgr.set_step(7);

        let results = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(results[0].1, HookResult::Success));

        let stats = mgr.get_hook_stats(id).expect("stats should exist");
        assert_eq!(stats.total_executions, 1);
        assert_eq!(stats.last_execution_step, Some(7));
        assert_eq!(stats.errors, 0);
        assert!(stats.total_execution_time_ms >= 0.0);
    }

    #[test]
    fn test_execute_hooks_skipped_hook_leaves_stats_untouched() {
        let mut mgr = HookManager::new();
        let id = mgr
            .register_hook(make_hook_config("h", HookTrigger::EveryBackward))
            .expect("register");

        let results = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(results[0].1, HookResult::Skipped(_)));

        let stats = mgr.get_hook_stats(id).expect("stats should exist");
        assert_eq!(stats.total_executions, 0);
        assert_eq!(stats.last_execution_step, None);
    }

    #[test]
    fn test_execute_hooks_once_fires_single_time() {
        let mut mgr = HookManager::new();
        let id = mgr
            .register_hook(make_hook_config("once", HookTrigger::Once))
            .expect("register");

        let first = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(first[0].1, HookResult::Success));

        let second = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(second[0].1, HookResult::Skipped(_)));

        // The hook stays registered after it has fired.
        assert!(mgr.get_hook(id).is_some());
    }

    #[test]
    fn test_execute_hooks_every_n_steps() {
        let mut mgr = HookManager::new();
        mgr.register_hook(make_hook_config("n", HookTrigger::EveryNSteps(3)))
            .expect("register");

        mgr.set_step(3);
        let on_multiple = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(on_multiple[0].1, HookResult::Success));

        mgr.set_step(4);
        let off_multiple = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(off_multiple[0].1, HookResult::Skipped(_)));
    }

    #[test]
    fn test_execute_hooks_step_range_condition() {
        let mut mgr = HookManager::new();
        let trigger = HookTrigger::Conditional(HookCondition::StepRange { start: 5, end: 10 });
        mgr.register_hook(make_hook_config("range", trigger))
            .expect("register");

        // Step 0 is below the range.
        let before = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(before[0].1, HookResult::Skipped(_)));

        // Both ends of the range are inclusive.
        mgr.set_step(5);
        let at_start = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(at_start[0].1, HookResult::Success));

        mgr.set_step(10);
        let at_end = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(at_end[0].1, HookResult::Success));

        mgr.set_step(11);
        let after = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(after[0].1, HookResult::Skipped(_)));
    }

    #[test]
    fn test_execute_hooks_layer_pattern_filters() {
        let mut mgr = HookManager::new();
        let mut cfg = make_hook_config("attn_only", HookTrigger::EveryForward);
        cfg.layer_patterns = vec!["^attn".to_string()];
        mgr.register_hook(cfg).expect("register");

        let matching = mgr.execute_hooks("attn.q_proj", &[1u8], &[1], true, None);
        assert!(matches!(matching[0].1, HookResult::Success));

        let other = mgr.execute_hooks("mlp.fc1", &[1u8], &[1], true, None);
        assert!(matches!(other[0].1, HookResult::Skipped(_)));
    }

    #[test]
    fn test_execute_hooks_missing_callback_reports_error() {
        let mut mgr = HookManager::new();
        let mut cfg = make_hook_config("cb", HookTrigger::EveryForward);
        cfg.actions = vec![HookAction::CustomCallback {
            name: "missing".to_string(),
        }];
        let id = mgr.register_hook(cfg).expect("register");

        let results = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert_eq!(results.len(), 1);
        assert!(matches!(results[0].1, HookResult::Error(_)));

        let stats = mgr.get_hook_stats(id).expect("stats should exist");
        assert_eq!(stats.total_executions, 1);
        assert_eq!(stats.errors, 1);
    }

    #[test]
    fn test_execute_hooks_custom_callback_receives_tensor_bytes() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let seen_len = Arc::new(AtomicUsize::new(0));
        let seen_in_callback = Arc::clone(&seen_len);

        let mut mgr = HookManager::new();
        mgr.register_callback(
            "probe".to_string(),
            Box::new(move |_ctx: &HookContext, data: &[u8]| -> Result<()> {
                seen_in_callback.store(data.len(), Ordering::SeqCst);
                Ok(())
            }),
        );

        let mut cfg = make_hook_config("cb", HookTrigger::EveryForward);
        cfg.actions = vec![HookAction::CustomCallback {
            name: "probe".to_string(),
        }];
        mgr.register_hook(cfg).expect("register");

        let results = mgr.execute_hooks("layer", &[1.0f32, 2.0, 3.0], &[3], true, None);
        assert_eq!(results.len(), 1);
        assert!(matches!(results[0].1, HookResult::Success));
        // Three f32 values are handed over as their raw bytes.
        assert_eq!(
            seen_len.load(Ordering::SeqCst),
            3 * std::mem::size_of::<f32>()
        );
    }

    // ── clear_hooks ────────────────────────────────────────────────────────

    #[test]
    fn test_clear_hooks_empties_everything() {
        let mut mgr = HookManager::new();
        mgr.register_hook(make_hook_config("h0", HookTrigger::EveryForward))
            .expect("register");
        mgr.register_hook(make_hook_config("h1", HookTrigger::EveryBackward))
            .expect("register");
        mgr.clear_hooks();
        assert!(mgr.get_all_hooks().is_empty());
        assert!(mgr.get_all_stats().is_empty());
    }

    // ── convenience builders ────────────────────────────────────────────────

    #[test]
    fn test_create_tensor_inspection_hook() {
        let mut mgr = HookManager::new();
        let id = mgr
            .create_tensor_inspection_hook(vec!["attention.*".to_string()])
            .expect("should succeed");
        let hook = mgr.get_hook(id).expect("should exist");
        assert!(matches!(hook.trigger, HookTrigger::EveryForward));
    }

    #[test]
    fn test_create_gradient_tracking_hook() {
        let mut mgr = HookManager::new();
        let id = mgr
            .create_gradient_tracking_hook(vec!["fc.*".to_string()])
            .expect("should succeed");
        let hook = mgr.get_hook(id).expect("should exist");
        assert!(matches!(hook.trigger, HookTrigger::EveryBackward));
    }

    #[test]
    fn test_create_alert_hook() {
        let mut mgr = HookManager::new();
        let cond = HookCondition::StepRange { start: 0, end: 100 };
        let id = mgr
            .create_alert_hook(cond, "loss exploded".to_string(), AlertSeverity::Critical)
            .expect("should succeed");
        let hook = mgr.get_hook(id).expect("should exist");
        assert!(matches!(hook.trigger, HookTrigger::Conditional(_)));
    }

    // ── HookBuilder ────────────────────────────────────────────────────────

    #[test]
    fn test_hook_builder_basic() {
        let cfg = HookBuilder::new("my_hook")
            .trigger(HookTrigger::EveryNSteps(10))
            .action(HookAction::TrackGradients)
            .max_executions(50)
            .layer_patterns(vec!["norm".to_string()])
            .enabled(true)
            .build();

        assert_eq!(cfg.name, "my_hook");
        assert!(matches!(cfg.trigger, HookTrigger::EveryNSteps(10)));
        assert_eq!(cfg.max_executions, Some(50));
        assert!(cfg.enabled);
        assert_eq!(cfg.actions.len(), 1);
        assert_eq!(cfg.layer_patterns, vec!["norm".to_string()]);
    }

    #[test]
    fn test_hook_builder_actions_replaces_previous() {
        let cfg = HookBuilder::new("replace")
            .action(HookAction::InspectTensor)
            .action(HookAction::TrackGradients)
            .actions(vec![HookAction::PauseTraining])
            .build();

        assert_eq!(cfg.actions.len(), 1);
        assert!(matches!(cfg.actions[0], HookAction::PauseTraining));
    }

    // ── enum variants ──────────────────────────────────────────────────────

    #[test]
    fn test_hook_trigger_variants() {
        let triggers: Vec<String> = vec![
            format!("{:?}", HookTrigger::EveryForward),
            format!("{:?}", HookTrigger::EveryBackward),
            format!("{:?}", HookTrigger::EveryNSteps(5)),
            format!("{:?}", HookTrigger::Once),
            format!("{:?}", HookTrigger::LayerSpecific(vec![])),
        ];
        for t in &triggers {
            assert!(!t.is_empty());
        }
    }

    #[test]
    fn test_hook_action_variants() {
        let actions: Vec<String> = vec![
            format!("{:?}", HookAction::InspectTensor),
            format!("{:?}", HookAction::TrackGradients),
            format!("{:?}", HookAction::RecordActivations),
            format!(
                "{:?}",
                HookAction::SaveSnapshot {
                    path: "/tmp".to_string()
                }
            ),
            format!(
                "{:?}",
                HookAction::Alert {
                    message: "x".to_string(),
                    severity: AlertSeverity::Info
                }
            ),
            format!(
                "{:?}",
                HookAction::CustomCallback {
                    name: "cb".to_string()
                }
            ),
            format!("{:?}", HookAction::PauseTraining),
        ];
        for a in &actions {
            assert!(!a.is_empty());
        }
    }

    #[test]
    fn test_alert_severity_variants() {
        let severities = [
            AlertSeverity::Info,
            AlertSeverity::Warning,
            AlertSeverity::Critical,
        ];
        for s in &severities {
            assert!(!format!("{:?}", s).is_empty());
        }
    }

    #[test]
    fn test_comparison_variants() {
        let comps = [
            Comparison::Greater,
            Comparison::Less,
            Comparison::Equal,
            Comparison::GreaterEqual,
            Comparison::LessEqual,
        ];
        for c in &comps {
            assert!(!format!("{:?}", c).is_empty());
        }
    }

    #[test]
    fn test_hook_stats_fields() {
        let id = Uuid::new_v4();
        let stats = HookStats {
            hook_id: id,
            hook_name: "perf_hook".to_string(),
            total_executions: 100,
            last_execution_step: Some(99),
            total_execution_time_ms: 500.0,
            avg_execution_time_ms: 5.0,
            errors: 2,
        };
        assert_eq!(stats.total_executions, 100);
        assert_eq!(stats.errors, 2);
        assert_eq!(stats.last_execution_step, Some(99));
    }
}