1use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use uuid::Uuid;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub enum HookTrigger {
11 EveryForward,
13 EveryBackward,
15 EveryNSteps(usize),
17 Conditional(HookCondition),
19 Once,
21 LayerSpecific(Vec<String>),
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub enum HookCondition {
28 LossThreshold {
30 threshold: f64,
31 comparison: Comparison,
32 },
33 GradientNormThreshold {
35 threshold: f64,
36 comparison: Comparison,
37 },
38 MemoryThreshold { threshold_mb: f64 },
40 StepRange { start: usize, end: usize },
42 Custom(String),
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub enum Comparison {
48 Greater,
49 Less,
50 Equal,
51 GreaterEqual,
52 LessEqual,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub enum HookAction {
58 InspectTensor,
60 TrackGradients,
62 RecordActivations,
64 SaveSnapshot { path: String },
66 Alert {
68 message: String,
69 severity: AlertSeverity,
70 },
71 CustomCallback { name: String },
73 PauseTraining,
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub enum AlertSeverity {
79 Info,
80 Warning,
81 Critical,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct HookConfig {
87 pub id: Uuid,
88 pub name: String,
89 pub trigger: HookTrigger,
90 pub actions: Vec<HookAction>,
91 pub enabled: bool,
92 pub max_executions: Option<usize>,
93 pub layer_patterns: Vec<String>, }
95
/// Snapshot of the invocation a hook is running for.
///
/// Derives `Clone` (all fields are cloneable) so callbacks can keep a copy of
/// the context beyond the borrow they are handed.
#[derive(Debug, Clone)]
pub struct HookContext {
    /// Global step at the time the hooks were executed.
    pub step: usize,
    /// Name of the layer that produced the tensor.
    pub layer_name: String,
    /// Shape of the tensor as supplied by the caller.
    pub tensor_shape: Vec<usize>,
    /// `true` for a forward pass, `false` for a backward pass.
    pub is_forward: bool,
    /// Free-form key/value pairs supplied by the caller.
    pub metadata: HashMap<String, String>,
}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct HookStats {
109 pub hook_id: Uuid,
110 pub hook_name: String,
111 pub total_executions: usize,
112 pub last_execution_step: Option<usize>,
113 pub total_execution_time_ms: f64,
114 pub avg_execution_time_ms: f64,
115 pub errors: usize,
116}
117
/// Outcome of considering one hook during an execution pass.
///
/// Derives `Clone`, `PartialEq` and `Eq` so results can be stored and compared
/// with `assert_eq!` / `==` rather than only pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookResult {
    /// Every action of the hook completed.
    Success,
    /// An action failed; carries the error message.
    Error(String),
    /// The hook did not run; carries the reason.
    Skipped(String),
}
125
126pub type HookCallback = Box<dyn Fn(&HookContext, &[u8]) -> Result<()> + Send + Sync>;
128
129pub struct HookManager {
131 hooks: HashMap<Uuid, HookConfig>,
132 hook_stats: HashMap<Uuid, HookStats>,
133 callbacks: HashMap<String, HookCallback>,
134 execution_count: HashMap<Uuid, usize>,
135 global_step: usize,
136 enabled: bool,
137}
138
139impl std::fmt::Debug for HookManager {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 f.debug_struct("HookManager")
142 .field("hooks", &self.hooks)
143 .field("hook_stats", &self.hook_stats)
144 .field("execution_count", &self.execution_count)
145 .field("global_step", &self.global_step)
146 .field("enabled", &self.enabled)
147 .field("callbacks", &format!("{} callbacks", self.callbacks.len()))
148 .finish()
149 }
150}
151
152impl HookManager {
153 pub fn new() -> Self {
155 Self {
156 hooks: HashMap::new(),
157 hook_stats: HashMap::new(),
158 callbacks: HashMap::new(),
159 execution_count: HashMap::new(),
160 global_step: 0,
161 enabled: true,
162 }
163 }
164
165 pub fn register_hook(&mut self, config: HookConfig) -> Result<Uuid> {
167 let hook_id = config.id;
168
169 self.hook_stats.insert(
171 hook_id,
172 HookStats {
173 hook_id,
174 hook_name: config.name.clone(),
175 total_executions: 0,
176 last_execution_step: None,
177 total_execution_time_ms: 0.0,
178 avg_execution_time_ms: 0.0,
179 errors: 0,
180 },
181 );
182
183 self.execution_count.insert(hook_id, 0);
184 self.hooks.insert(hook_id, config);
185
186 tracing::debug!("Registered hook {}", hook_id);
187 Ok(hook_id)
188 }
189
190 pub fn register_callback(&mut self, name: String, callback: HookCallback) {
192 self.callbacks.insert(name, callback);
193 }
194
195 pub fn remove_hook(&mut self, hook_id: Uuid) -> Option<HookConfig> {
197 self.hook_stats.remove(&hook_id);
198 self.execution_count.remove(&hook_id);
199 self.hooks.remove(&hook_id)
200 }
201
202 pub fn set_hook_enabled(&mut self, hook_id: Uuid, enabled: bool) -> Result<()> {
204 if let Some(hook) = self.hooks.get_mut(&hook_id) {
205 hook.enabled = enabled;
206 Ok(())
207 } else {
208 Err(anyhow::anyhow!("Hook {} not found", hook_id))
209 }
210 }
211
212 pub fn set_enabled(&mut self, enabled: bool) {
214 self.enabled = enabled;
215 }
216
217 pub fn set_step(&mut self, step: usize) {
219 self.global_step = step;
220 }
221
222 pub fn execute_hooks<T>(
224 &mut self,
225 layer_name: &str,
226 tensor_data: &[T],
227 tensor_shape: &[usize],
228 is_forward: bool,
229 metadata: Option<HashMap<String, String>>,
230 ) -> Vec<(Uuid, HookResult)>
231 where
232 T: Clone + 'static,
233 {
234 if !self.enabled {
235 return Vec::new();
236 }
237
238 let context = HookContext {
239 step: self.global_step,
240 layer_name: layer_name.to_string(),
241 tensor_shape: tensor_shape.to_vec(),
242 is_forward,
243 metadata: metadata.unwrap_or_default(),
244 };
245
246 let mut results = Vec::new();
247
248 let tensor_bytes = unsafe {
250 std::slice::from_raw_parts(
251 tensor_data.as_ptr() as *const u8,
252 std::mem::size_of_val(tensor_data),
253 )
254 };
255
256 let hooks_to_execute: Vec<(Uuid, HookConfig)> =
258 self.hooks.iter().map(|(id, config)| (*id, config.clone())).collect();
259
260 for (hook_id, hook_config) in hooks_to_execute {
261 if !hook_config.enabled {
262 continue;
263 }
264
265 if let Some(should_execute) = self.should_execute_hook(&hook_config, &context) {
267 if !should_execute {
268 results.push((
269 hook_id,
270 HookResult::Skipped("Condition not met".to_string()),
271 ));
272 continue;
273 }
274 }
275
276 let current_count = self.execution_count.get(&hook_id).copied().unwrap_or(0);
278 if let Some(max_executions) = hook_config.max_executions {
279 if current_count >= max_executions {
280 results.push((
281 hook_id,
282 HookResult::Skipped("Max executions reached".to_string()),
283 ));
284 continue;
285 }
286 }
287
288 let start_time = std::time::Instant::now();
290 let result = self.execute_single_hook(&hook_config, &context, tensor_bytes);
291 let execution_time = start_time.elapsed().as_millis() as f64;
292
293 if let Some(stats) = self.hook_stats.get_mut(&hook_id) {
295 stats.total_executions += 1;
296 stats.last_execution_step = Some(self.global_step);
297 stats.total_execution_time_ms += execution_time;
298 stats.avg_execution_time_ms =
299 stats.total_execution_time_ms / stats.total_executions as f64;
300
301 if matches!(result, HookResult::Error(_)) {
302 stats.errors += 1;
303 }
304 }
305
306 if let Some(count) = self.execution_count.get_mut(&hook_id) {
308 *count += 1;
309 }
310
311 results.push((hook_id, result));
312 }
313
314 results
315 }
316
317 pub fn get_hook(&self, hook_id: Uuid) -> Option<&HookConfig> {
319 self.hooks.get(&hook_id)
320 }
321
322 pub fn get_all_hooks(&self) -> Vec<&HookConfig> {
324 self.hooks.values().collect()
325 }
326
327 pub fn get_hook_stats(&self, hook_id: Uuid) -> Option<&HookStats> {
329 self.hook_stats.get(&hook_id)
330 }
331
332 pub fn get_all_stats(&self) -> Vec<&HookStats> {
334 self.hook_stats.values().collect()
335 }
336
337 pub fn clear_hooks(&mut self) {
339 self.hooks.clear();
340 self.hook_stats.clear();
341 self.execution_count.clear();
342 self.callbacks.clear();
343 }
344
345 pub fn create_tensor_inspection_hook(&mut self, layer_patterns: Vec<String>) -> Result<Uuid> {
347 let config = HookConfig {
348 id: Uuid::new_v4(),
349 name: "Tensor Inspector".to_string(),
350 trigger: HookTrigger::EveryForward,
351 actions: vec![HookAction::InspectTensor],
352 enabled: true,
353 max_executions: None,
354 layer_patterns,
355 };
356
357 self.register_hook(config)
358 }
359
360 pub fn create_gradient_tracking_hook(&mut self, layer_patterns: Vec<String>) -> Result<Uuid> {
362 let config = HookConfig {
363 id: Uuid::new_v4(),
364 name: "Gradient Tracker".to_string(),
365 trigger: HookTrigger::EveryBackward,
366 actions: vec![HookAction::TrackGradients],
367 enabled: true,
368 max_executions: None,
369 layer_patterns,
370 };
371
372 self.register_hook(config)
373 }
374
375 pub fn create_alert_hook(
377 &mut self,
378 condition: HookCondition,
379 message: String,
380 severity: AlertSeverity,
381 ) -> Result<Uuid> {
382 let config = HookConfig {
383 id: Uuid::new_v4(),
384 name: "Alert Hook".to_string(),
385 trigger: HookTrigger::Conditional(condition),
386 actions: vec![HookAction::Alert { message, severity }],
387 enabled: true,
388 max_executions: None,
389 layer_patterns: vec![".*".to_string()], };
391
392 self.register_hook(config)
393 }
394
395 fn should_execute_hook(&self, hook: &HookConfig, context: &HookContext) -> Option<bool> {
398 if !hook.layer_patterns.is_empty() {
400 let matches_pattern = hook.layer_patterns.iter().any(|pattern| {
401 regex::Regex::new(pattern)
402 .map(|re| re.is_match(&context.layer_name))
403 .unwrap_or(false)
404 });
405
406 if !matches_pattern {
407 return Some(false);
408 }
409 }
410
411 match &hook.trigger {
412 HookTrigger::EveryForward => Some(context.is_forward),
413 HookTrigger::EveryBackward => Some(!context.is_forward),
414 HookTrigger::EveryNSteps(n) => Some(context.step.is_multiple_of(*n)),
415 HookTrigger::Conditional(condition) => {
416 Some(self.evaluate_condition(condition, context))
417 },
418 HookTrigger::Once => {
419 let count = self.execution_count.get(&hook.id).copied().unwrap_or(0);
420 Some(count == 0)
421 },
422 HookTrigger::LayerSpecific(layers) => Some(layers.contains(&context.layer_name)),
423 }
424 }
425
426 fn evaluate_condition(&self, condition: &HookCondition, context: &HookContext) -> bool {
427 match condition {
428 HookCondition::StepRange { start, end } => {
429 context.step >= *start && context.step <= *end
430 },
431 HookCondition::Custom(name) => {
432 context.metadata.contains_key(name)
435 },
436 _ => true,
438 }
439 }
440
441 fn execute_single_hook(
442 &mut self,
443 hook: &HookConfig,
444 context: &HookContext,
445 tensor_data: &[u8],
446 ) -> HookResult {
447 for action in &hook.actions {
448 match self.execute_action(action, context, tensor_data) {
449 Ok(()) => continue,
450 Err(e) => return HookResult::Error(e.to_string()),
451 }
452 }
453 HookResult::Success
454 }
455
456 fn execute_action(
457 &mut self,
458 action: &HookAction,
459 context: &HookContext,
460 tensor_data: &[u8],
461 ) -> Result<()> {
462 match action {
463 HookAction::InspectTensor => {
464 tracing::debug!(
465 "Inspecting tensor in layer '{}' at step {}",
466 context.layer_name,
467 context.step
468 );
469 Ok(())
471 },
472 HookAction::TrackGradients => {
473 tracing::debug!(
474 "Tracking gradients in layer '{}' at step {}",
475 context.layer_name,
476 context.step
477 );
478 Ok(())
480 },
481 HookAction::RecordActivations => {
482 tracing::debug!(
483 "Recording activations in layer '{}' at step {}",
484 context.layer_name,
485 context.step
486 );
487 Ok(())
489 },
490 HookAction::SaveSnapshot { path } => {
491 let file_path =
492 format!("{}_{}_step_{}.bin", path, context.layer_name, context.step);
493 std::fs::write(&file_path, tensor_data)?;
494 tracing::info!("Saved tensor snapshot to {}", file_path);
495 Ok(())
496 },
497 HookAction::Alert { message, severity } => {
498 match severity {
499 AlertSeverity::Info => tracing::info!("Hook Alert: {}", message),
500 AlertSeverity::Warning => tracing::warn!("Hook Alert: {}", message),
501 AlertSeverity::Critical => tracing::error!("Hook Alert: {}", message),
502 }
503 Ok(())
504 },
505 HookAction::CustomCallback { name } => {
506 if let Some(callback) = self.callbacks.get(name) {
507 callback(context, tensor_data)?;
508 } else {
509 return Err(anyhow::anyhow!("Callback '{}' not found", name));
510 }
511 Ok(())
512 },
513 HookAction::PauseTraining => {
514 tracing::warn!(
515 "Training paused by hook at step {} in layer '{}'",
516 context.step,
517 context.layer_name
518 );
519 Ok(())
521 },
522 }
523 }
524}
525
526impl Default for HookManager {
527 fn default() -> Self {
528 Self::new()
529 }
530}
531
532pub struct HookBuilder {
534 config: HookConfig,
535}
536
537impl HookBuilder {
538 pub fn new(name: &str) -> Self {
539 Self {
540 config: HookConfig {
541 id: Uuid::new_v4(),
542 name: name.to_string(),
543 trigger: HookTrigger::EveryForward,
544 actions: Vec::new(),
545 enabled: true,
546 max_executions: None,
547 layer_patterns: Vec::new(),
548 },
549 }
550 }
551
552 pub fn trigger(mut self, trigger: HookTrigger) -> Self {
553 self.config.trigger = trigger;
554 self
555 }
556
557 pub fn action(mut self, action: HookAction) -> Self {
558 self.config.actions.push(action);
559 self
560 }
561
562 pub fn actions(mut self, actions: Vec<HookAction>) -> Self {
563 self.config.actions = actions;
564 self
565 }
566
567 pub fn max_executions(mut self, max: usize) -> Self {
568 self.config.max_executions = Some(max);
569 self
570 }
571
572 pub fn layer_patterns(mut self, patterns: Vec<String>) -> Self {
573 self.config.layer_patterns = patterns;
574 self
575 }
576
577 pub fn enabled(mut self, enabled: bool) -> Self {
578 self.config.enabled = enabled;
579 self
580 }
581
582 pub fn build(self) -> HookConfig {
583 self.config
584 }
585}
586
/// Builds a `HookConfig` that inspects tensors on every forward pass.
///
/// `tensor_hook!(name, patterns)` takes the hook name (`&str`) and the layer
/// patterns (`Vec<String>`). A trailing comma after the last argument is
/// accepted.
///
/// The expansion refers to `HookBuilder`, `HookTrigger` and `HookAction` by
/// their bare names, so all three must be in scope at the call site.
#[macro_export]
macro_rules! tensor_hook {
    ($name:expr, $patterns:expr $(,)?) => {
        HookBuilder::new($name)
            .trigger(HookTrigger::EveryForward)
            .action(HookAction::InspectTensor)
            .layer_patterns($patterns)
            .build()
    };
}
598
/// Builds a `HookConfig` that tracks gradients on every backward pass.
///
/// `gradient_hook!(name, patterns)` takes the hook name (`&str`) and the layer
/// patterns (`Vec<String>`). A trailing comma after the last argument is
/// accepted.
///
/// The expansion refers to `HookBuilder`, `HookTrigger` and `HookAction` by
/// their bare names, so all three must be in scope at the call site.
#[macro_export]
macro_rules! gradient_hook {
    ($name:expr, $patterns:expr $(,)?) => {
        HookBuilder::new($name)
            .trigger(HookTrigger::EveryBackward)
            .action(HookAction::TrackGradients)
            .layer_patterns($patterns)
            .build()
    };
}
609
/// Builds a `HookConfig` named "Alert Hook" that logs a message when a
/// condition holds.
///
/// `alert_hook!(condition, message, severity)` takes a `HookCondition`, any
/// value with a `to_string()` method, and an `AlertSeverity`. A trailing comma
/// after the last argument is accepted. No layer patterns are set, so the hook
/// applies to every layer.
///
/// The expansion refers to `HookBuilder`, `HookTrigger` and `HookAction` by
/// their bare names, so all three must be in scope at the call site.
#[macro_export]
macro_rules! alert_hook {
    ($condition:expr, $message:expr, $severity:expr $(,)?) => {
        HookBuilder::new("Alert Hook")
            .trigger(HookTrigger::Conditional($condition))
            .action(HookAction::Alert {
                message: $message.to_string(),
                severity: $severity,
            })
            .build()
    };
}
622
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled hook with one `InspectTensor` action, no execution
    /// limit and no layer patterns.
    fn make_hook_config(name: &str, trigger: HookTrigger) -> HookConfig {
        HookConfig {
            id: Uuid::new_v4(),
            name: name.to_string(),
            trigger,
            actions: vec![HookAction::InspectTensor],
            enabled: true,
            max_executions: None,
            layer_patterns: vec![],
        }
    }

    // ---- construction -----------------------------------------------------

    #[test]
    fn test_hook_manager_new_defaults() {
        let mgr = HookManager::new();
        assert!(mgr.enabled);
        assert_eq!(mgr.global_step, 0);
        assert!(mgr.get_all_hooks().is_empty());
        assert!(mgr.get_all_stats().is_empty());
    }

    #[test]
    fn test_hook_manager_default_equals_new() {
        let mgr = HookManager::default();
        assert!(mgr.enabled);
        assert_eq!(mgr.global_step, 0);
        assert!(mgr.get_all_hooks().is_empty());
        assert!(mgr.get_all_stats().is_empty());
    }

    // ---- registration -----------------------------------------------------

    #[test]
    fn test_register_hook_returns_uuid() {
        let mut mgr = HookManager::new();
        let config = make_hook_config("test", HookTrigger::EveryForward);
        let id = config.id;
        let returned = mgr.register_hook(config).expect("register should succeed");
        assert_eq!(returned, id);
    }

    #[test]
    fn test_register_multiple_hooks() {
        let mut mgr = HookManager::new();
        for i in 0..5 {
            let cfg = make_hook_config(&format!("h{}", i), HookTrigger::EveryForward);
            mgr.register_hook(cfg).expect("register should succeed");
        }
        assert_eq!(mgr.get_all_hooks().len(), 5);
        assert_eq!(mgr.get_all_stats().len(), 5);
    }

    #[test]
    fn test_hook_stats_initialized_on_register() {
        let mut mgr = HookManager::new();
        let cfg = make_hook_config("h0", HookTrigger::EveryForward);
        let id = mgr.register_hook(cfg).expect("register should succeed");
        let stats = mgr.get_hook_stats(id).expect("stats should exist");
        assert_eq!(stats.hook_id, id);
        assert_eq!(stats.hook_name, "h0");
        assert_eq!(stats.total_executions, 0);
        assert_eq!(stats.last_execution_step, None);
        assert_eq!(stats.errors, 0);
    }

    // ---- removal ----------------------------------------------------------

    #[test]
    fn test_remove_hook_returns_config() {
        let mut mgr = HookManager::new();
        let cfg = make_hook_config("remove_me", HookTrigger::EveryBackward);
        let id = mgr.register_hook(cfg).expect("register");
        let removed = mgr.remove_hook(id);
        assert!(removed.is_some());
        assert_eq!(removed.expect("should be some").name, "remove_me");
        // The hook and its statistics must both be gone afterwards.
        assert!(mgr.get_hook(id).is_none());
        assert!(mgr.get_hook_stats(id).is_none());
    }

    #[test]
    fn test_remove_nonexistent_hook_returns_none() {
        let mut mgr = HookManager::new();
        let id = Uuid::new_v4();
        assert!(mgr.remove_hook(id).is_none());
    }

    // ---- enabling / disabling ---------------------------------------------

    #[test]
    fn test_set_hook_enabled_ok() {
        let mut mgr = HookManager::new();
        let cfg = make_hook_config("h", HookTrigger::EveryForward);
        let id = mgr.register_hook(cfg).expect("register");
        mgr.set_hook_enabled(id, false).expect("should succeed");
        let hook = mgr.get_hook(id).expect("hook should exist");
        assert!(!hook.enabled);
        mgr.set_hook_enabled(id, true).expect("re-enable");
        let hook = mgr.get_hook(id).expect("hook should exist");
        assert!(hook.enabled);
    }

    #[test]
    fn test_set_hook_enabled_nonexistent_errors() {
        let mut mgr = HookManager::new();
        let result = mgr.set_hook_enabled(Uuid::new_v4(), true);
        assert!(result.is_err());
    }

    #[test]
    fn test_global_disable_stops_execution() {
        let mut mgr = HookManager::new();
        mgr.set_enabled(false);
        mgr.register_hook(make_hook_config("h", HookTrigger::EveryForward))
            .expect("register");
        let results = mgr.execute_hooks("layer", &[1u8, 2u8], &[2], true, None);
        assert!(
            results.is_empty(),
            "globally disabled manager should execute nothing"
        );
    }

    // ---- step tracking ----------------------------------------------------

    #[test]
    fn test_set_step_updates_counter() {
        let mut mgr = HookManager::new();
        mgr.set_step(42);
        assert_eq!(mgr.global_step, 42);
    }

    // ---- execution --------------------------------------------------------

    #[test]
    fn test_execute_hooks_disabled_hook_skipped() {
        let mut mgr = HookManager::new();
        let mut cfg = make_hook_config("h", HookTrigger::EveryForward);
        cfg.enabled = false;
        mgr.register_hook(cfg).expect("register");
        let results = mgr.execute_hooks("layer", &[0u8], &[1], true, None);
        // Disabled hooks produce no result entry at all.
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_execute_hooks_every_forward_fires_on_forward() {
        let mut mgr = HookManager::new();
        let id = mgr
            .register_hook(make_hook_config("h", HookTrigger::EveryForward))
            .expect("register");
        let results = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, id);
        assert!(matches!(results[0].1, HookResult::Success));
    }

    #[test]
    fn test_execute_hooks_every_forward_skipped_on_backward() {
        let mut mgr = HookManager::new();
        let cfg = make_hook_config("h", HookTrigger::EveryForward);
        mgr.register_hook(cfg).expect("register");
        let results = mgr.execute_hooks("layer", &[1u8], &[1], false, None);
        assert_eq!(results.len(), 1);
        assert!(matches!(&results[0].1, HookResult::Skipped(_)));
    }

    #[test]
    fn test_execute_hooks_max_executions_respected() {
        let mut mgr = HookManager::new();
        let mut cfg = make_hook_config("once", HookTrigger::EveryForward);
        cfg.max_executions = Some(1);
        mgr.register_hook(cfg).expect("register");

        let r1 = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert_eq!(r1.len(), 1);
        assert!(matches!(r1[0].1, HookResult::Success));

        let r2 = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert_eq!(r2.len(), 1);
        assert!(matches!(r2[0].1, HookResult::Skipped(_)));
    }

    #[test]
    fn test_execute_hooks_updates_stats() {
        let mut mgr = HookManager::new();
        let id = mgr
            .register_hook(make_hook_config("h", HookTrigger::EveryForward))
            .expect("register");
        mgr.set_step(7);

        mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        mgr.execute_hooks("layer", &[1u8], &[1], true, None);

        let stats = mgr.get_hook_stats(id).expect("stats should exist");
        assert_eq!(stats.total_executions, 2);
        assert_eq!(stats.last_execution_step, Some(7));
        assert_eq!(stats.errors, 0);
        assert!(stats.total_execution_time_ms >= 0.0);
        assert!(stats.avg_execution_time_ms >= 0.0);
    }

    #[test]
    fn test_skipped_hook_does_not_touch_stats() {
        let mut mgr = HookManager::new();
        let id = mgr
            .register_hook(make_hook_config("h", HookTrigger::EveryForward))
            .expect("register");

        mgr.execute_hooks("layer", &[1u8], &[1], false, None);

        let stats = mgr.get_hook_stats(id).expect("stats should exist");
        assert_eq!(stats.total_executions, 0);
        assert_eq!(stats.last_execution_step, None);
    }

    #[test]
    fn test_once_trigger_fires_single_time() {
        let mut mgr = HookManager::new();
        mgr.register_hook(make_hook_config("once", HookTrigger::Once))
            .expect("register");

        let first = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(first[0].1, HookResult::Success));

        let second = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(second[0].1, HookResult::Skipped(_)));
    }

    #[test]
    fn test_layer_patterns_filter_by_regex() {
        let mut mgr = HookManager::new();
        let mut cfg = make_hook_config("attn_only", HookTrigger::EveryForward);
        cfg.layer_patterns = vec!["^attn".to_string()];
        mgr.register_hook(cfg).expect("register");

        let miss = mgr.execute_hooks("mlp.0", &[1u8], &[1], true, None);
        assert!(matches!(miss[0].1, HookResult::Skipped(_)));

        let hit = mgr.execute_hooks("attn.0", &[1u8], &[1], true, None);
        assert!(matches!(hit[0].1, HookResult::Success));
    }

    #[test]
    fn test_step_range_condition() {
        let mut mgr = HookManager::new();
        let trigger = HookTrigger::Conditional(HookCondition::StepRange { start: 5, end: 10 });
        mgr.register_hook(make_hook_config("ranged", trigger))
            .expect("register");

        let before = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(before[0].1, HookResult::Skipped(_)));

        mgr.set_step(5);
        let inside = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(inside[0].1, HookResult::Success));

        mgr.set_step(11);
        let after = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert!(matches!(after[0].1, HookResult::Skipped(_)));
    }

    #[test]
    fn test_missing_callback_reports_error() {
        let mut mgr = HookManager::new();
        let mut cfg = make_hook_config("cb", HookTrigger::EveryForward);
        cfg.actions = vec![HookAction::CustomCallback {
            name: "missing".to_string(),
        }];
        let id = mgr.register_hook(cfg).expect("register");

        let results = mgr.execute_hooks("layer", &[1u8], &[1], true, None);
        assert_eq!(results.len(), 1);
        assert!(matches!(&results[0].1, HookResult::Error(msg) if msg.contains("missing")));

        let stats = mgr.get_hook_stats(id).expect("stats should exist");
        assert_eq!(stats.total_executions, 1);
        assert_eq!(stats.errors, 1);
    }

    #[test]
    fn test_registered_callback_receives_tensor_bytes() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let seen_len = Arc::new(AtomicUsize::new(0));
        let sink = Arc::clone(&seen_len);

        let mut mgr = HookManager::new();
        mgr.register_callback(
            "probe".to_string(),
            Box::new(move |_ctx: &HookContext, bytes: &[u8]| {
                sink.store(bytes.len(), Ordering::SeqCst);
                Ok(())
            }),
        );

        let mut cfg = make_hook_config("cb", HookTrigger::EveryForward);
        cfg.actions = vec![HookAction::CustomCallback {
            name: "probe".to_string(),
        }];
        mgr.register_hook(cfg).expect("register");

        let results = mgr.execute_hooks("layer", &[1u32, 2u32], &[2], true, None);
        assert_eq!(results.len(), 1);
        assert!(matches!(results[0].1, HookResult::Success));
        assert_eq!(
            seen_len.load(Ordering::SeqCst),
            2 * std::mem::size_of::<u32>()
        );
    }

    // ---- clearing ---------------------------------------------------------

    #[test]
    fn test_clear_hooks_empties_everything() {
        let mut mgr = HookManager::new();
        mgr.register_hook(make_hook_config("h0", HookTrigger::EveryForward))
            .expect("register");
        mgr.register_hook(make_hook_config("h1", HookTrigger::EveryBackward))
            .expect("register");
        mgr.clear_hooks();
        assert!(mgr.get_all_hooks().is_empty());
        assert!(mgr.get_all_stats().is_empty());
    }

    // ---- convenience constructors -----------------------------------------

    #[test]
    fn test_create_tensor_inspection_hook() {
        let mut mgr = HookManager::new();
        let id = mgr
            .create_tensor_inspection_hook(vec!["attention.*".to_string()])
            .expect("should succeed");
        let hook = mgr.get_hook(id).expect("should exist");
        assert!(matches!(hook.trigger, HookTrigger::EveryForward));
        assert_eq!(hook.layer_patterns, vec!["attention.*".to_string()]);
    }

    #[test]
    fn test_create_gradient_tracking_hook() {
        let mut mgr = HookManager::new();
        let id = mgr
            .create_gradient_tracking_hook(vec!["fc.*".to_string()])
            .expect("should succeed");
        let hook = mgr.get_hook(id).expect("should exist");
        assert!(matches!(hook.trigger, HookTrigger::EveryBackward));
    }

    #[test]
    fn test_create_alert_hook() {
        let mut mgr = HookManager::new();
        let cond = HookCondition::StepRange { start: 0, end: 100 };
        let id = mgr
            .create_alert_hook(cond, "loss exploded".to_string(), AlertSeverity::Critical)
            .expect("should succeed");
        let hook = mgr.get_hook(id).expect("should exist");
        assert!(matches!(hook.trigger, HookTrigger::Conditional(_)));
        assert!(matches!(
            hook.actions.as_slice(),
            [HookAction::Alert {
                severity: AlertSeverity::Critical,
                ..
            }]
        ));
    }

    // ---- builder and macros -----------------------------------------------

    #[test]
    fn test_hook_builder_basic() {
        let cfg = HookBuilder::new("my_hook")
            .trigger(HookTrigger::EveryNSteps(10))
            .action(HookAction::TrackGradients)
            .max_executions(50)
            .layer_patterns(vec!["norm".to_string()])
            .enabled(true)
            .build();

        assert_eq!(cfg.name, "my_hook");
        assert!(matches!(cfg.trigger, HookTrigger::EveryNSteps(10)));
        assert!(matches!(
            cfg.actions.as_slice(),
            [HookAction::TrackGradients]
        ));
        assert_eq!(cfg.max_executions, Some(50));
        assert_eq!(cfg.layer_patterns, vec!["norm".to_string()]);
        assert!(cfg.enabled);
    }

    #[test]
    fn test_macros_build_expected_configs() {
        let inspect = tensor_hook!("inspect", vec!["attn.*".to_string()]);
        assert_eq!(inspect.name, "inspect");
        assert!(matches!(inspect.trigger, HookTrigger::EveryForward));
        assert!(matches!(
            inspect.actions.as_slice(),
            [HookAction::InspectTensor]
        ));

        let grads = gradient_hook!("grads", vec!["fc.*".to_string()]);
        assert_eq!(grads.name, "grads");
        assert!(matches!(grads.trigger, HookTrigger::EveryBackward));
        assert!(matches!(
            grads.actions.as_slice(),
            [HookAction::TrackGradients]
        ));

        let alert = alert_hook!(
            HookCondition::StepRange { start: 0, end: 1 },
            "boom",
            AlertSeverity::Warning
        );
        assert_eq!(alert.name, "Alert Hook");
        assert!(matches!(alert.trigger, HookTrigger::Conditional(_)));
        assert!(alert.layer_patterns.is_empty());
    }

    // ---- plain data types -------------------------------------------------

    #[test]
    fn test_hook_trigger_variants() {
        let triggers: Vec<String> = vec![
            format!("{:?}", HookTrigger::EveryForward),
            format!("{:?}", HookTrigger::EveryBackward),
            format!("{:?}", HookTrigger::EveryNSteps(5)),
            format!(
                "{:?}",
                HookTrigger::Conditional(HookCondition::Custom("flag".to_string()))
            ),
            format!("{:?}", HookTrigger::Once),
            format!("{:?}", HookTrigger::LayerSpecific(vec![])),
        ];
        for t in &triggers {
            assert!(!t.is_empty());
        }
    }

    #[test]
    fn test_hook_action_variants() {
        let actions: Vec<String> = vec![
            format!("{:?}", HookAction::InspectTensor),
            format!("{:?}", HookAction::TrackGradients),
            format!("{:?}", HookAction::RecordActivations),
            format!(
                "{:?}",
                HookAction::SaveSnapshot {
                    path: "/tmp".to_string()
                }
            ),
            format!(
                "{:?}",
                HookAction::Alert {
                    message: "x".to_string(),
                    severity: AlertSeverity::Info
                }
            ),
            format!(
                "{:?}",
                HookAction::CustomCallback {
                    name: "cb".to_string()
                }
            ),
            format!("{:?}", HookAction::PauseTraining),
        ];
        for a in &actions {
            assert!(!a.is_empty());
        }
    }

    #[test]
    fn test_alert_severity_variants() {
        let severities = [
            AlertSeverity::Info,
            AlertSeverity::Warning,
            AlertSeverity::Critical,
        ];
        for s in &severities {
            assert!(!format!("{:?}", s).is_empty());
        }
    }

    #[test]
    fn test_comparison_variants() {
        let comps = [
            Comparison::Greater,
            Comparison::Less,
            Comparison::Equal,
            Comparison::GreaterEqual,
            Comparison::LessEqual,
        ];
        for c in &comps {
            assert!(!format!("{:?}", c).is_empty());
        }
    }

    #[test]
    fn test_hook_stats_fields() {
        let id = Uuid::new_v4();
        let stats = HookStats {
            hook_id: id,
            hook_name: "perf_hook".to_string(),
            total_executions: 100,
            last_execution_step: Some(99),
            total_execution_time_ms: 500.0,
            avg_execution_time_ms: 5.0,
            errors: 2,
        };
        assert_eq!(stats.hook_id, id);
        assert_eq!(stats.total_executions, 100);
        assert_eq!(stats.errors, 2);
        assert_eq!(stats.last_execution_step, Some(99));
    }
}