1use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use uuid::Uuid;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub enum HookTrigger {
11 EveryForward,
13 EveryBackward,
15 EveryNSteps(usize),
17 Conditional(HookCondition),
19 Once,
21 LayerSpecific(Vec<String>),
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub enum HookCondition {
28 LossThreshold {
30 threshold: f64,
31 comparison: Comparison,
32 },
33 GradientNormThreshold {
35 threshold: f64,
36 comparison: Comparison,
37 },
38 MemoryThreshold { threshold_mb: f64 },
40 StepRange { start: usize, end: usize },
42 Custom(String),
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub enum Comparison {
48 Greater,
49 Less,
50 Equal,
51 GreaterEqual,
52 LessEqual,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub enum HookAction {
58 InspectTensor,
60 TrackGradients,
62 RecordActivations,
64 SaveSnapshot { path: String },
66 Alert {
68 message: String,
69 severity: AlertSeverity,
70 },
71 CustomCallback { name: String },
73 PauseTraining,
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub enum AlertSeverity {
79 Info,
80 Warning,
81 Critical,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct HookConfig {
87 pub id: Uuid,
88 pub name: String,
89 pub trigger: HookTrigger,
90 pub actions: Vec<HookAction>,
91 pub enabled: bool,
92 pub max_executions: Option<usize>,
93 pub layer_patterns: Vec<String>, }
95
/// Snapshot of where a hook is being executed, handed to actions and callbacks.
#[derive(Debug)]
pub struct HookContext {
    /// Training step at which the hooks are executed.
    pub step: usize,
    /// Name of the layer that produced the tensor.
    pub layer_name: String,
    /// Shape of the tensor as supplied by the caller.
    pub tensor_shape: Vec<usize>,
    /// `true` for a forward pass, `false` for a backward pass.
    pub is_forward: bool,
    /// Free-form key/value pairs supplied by the caller (empty when none were given).
    pub metadata: HashMap<String, String>,
}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct HookStats {
109 pub hook_id: Uuid,
110 pub hook_name: String,
111 pub total_executions: usize,
112 pub last_execution_step: Option<usize>,
113 pub total_execution_time_ms: f64,
114 pub avg_execution_time_ms: f64,
115 pub errors: usize,
116}
117
/// Outcome of considering one hook for one `execute_hooks` call.
///
/// Derives `Clone`, `PartialEq` and `Eq` so callers can copy, compare and assert on results
/// (e.g. `assert_eq!(result, HookResult::Success)`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookResult {
    /// Every action of the hook completed without error.
    Success,
    /// The first failing action's error message.
    Error(String),
    /// The hook was not run; the string gives the reason.
    Skipped(String),
}
125
126pub type HookCallback = Box<dyn Fn(&HookContext, &[u8]) -> Result<()> + Send + Sync>;
128
129pub struct HookManager {
131 hooks: HashMap<Uuid, HookConfig>,
132 hook_stats: HashMap<Uuid, HookStats>,
133 callbacks: HashMap<String, HookCallback>,
134 execution_count: HashMap<Uuid, usize>,
135 global_step: usize,
136 enabled: bool,
137}
138
139impl std::fmt::Debug for HookManager {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 f.debug_struct("HookManager")
142 .field("hooks", &self.hooks)
143 .field("hook_stats", &self.hook_stats)
144 .field("execution_count", &self.execution_count)
145 .field("global_step", &self.global_step)
146 .field("enabled", &self.enabled)
147 .field("callbacks", &format!("{} callbacks", self.callbacks.len()))
148 .finish()
149 }
150}
151
152impl HookManager {
153 pub fn new() -> Self {
155 Self {
156 hooks: HashMap::new(),
157 hook_stats: HashMap::new(),
158 callbacks: HashMap::new(),
159 execution_count: HashMap::new(),
160 global_step: 0,
161 enabled: true,
162 }
163 }
164
165 pub fn register_hook(&mut self, config: HookConfig) -> Result<Uuid> {
167 let hook_id = config.id;
168
169 self.hook_stats.insert(
171 hook_id,
172 HookStats {
173 hook_id,
174 hook_name: config.name.clone(),
175 total_executions: 0,
176 last_execution_step: None,
177 total_execution_time_ms: 0.0,
178 avg_execution_time_ms: 0.0,
179 errors: 0,
180 },
181 );
182
183 self.execution_count.insert(hook_id, 0);
184 self.hooks.insert(hook_id, config);
185
186 tracing::debug!("Registered hook {}", hook_id);
187 Ok(hook_id)
188 }
189
190 pub fn register_callback(&mut self, name: String, callback: HookCallback) {
192 self.callbacks.insert(name, callback);
193 }
194
195 pub fn remove_hook(&mut self, hook_id: Uuid) -> Option<HookConfig> {
197 self.hook_stats.remove(&hook_id);
198 self.execution_count.remove(&hook_id);
199 self.hooks.remove(&hook_id)
200 }
201
202 pub fn set_hook_enabled(&mut self, hook_id: Uuid, enabled: bool) -> Result<()> {
204 if let Some(hook) = self.hooks.get_mut(&hook_id) {
205 hook.enabled = enabled;
206 Ok(())
207 } else {
208 Err(anyhow::anyhow!("Hook {} not found", hook_id))
209 }
210 }
211
212 pub fn set_enabled(&mut self, enabled: bool) {
214 self.enabled = enabled;
215 }
216
217 pub fn set_step(&mut self, step: usize) {
219 self.global_step = step;
220 }
221
222 pub fn execute_hooks<T>(
224 &mut self,
225 layer_name: &str,
226 tensor_data: &[T],
227 tensor_shape: &[usize],
228 is_forward: bool,
229 metadata: Option<HashMap<String, String>>,
230 ) -> Vec<(Uuid, HookResult)>
231 where
232 T: Clone + 'static,
233 {
234 if !self.enabled {
235 return Vec::new();
236 }
237
238 let context = HookContext {
239 step: self.global_step,
240 layer_name: layer_name.to_string(),
241 tensor_shape: tensor_shape.to_vec(),
242 is_forward,
243 metadata: metadata.unwrap_or_default(),
244 };
245
246 let mut results = Vec::new();
247
248 let tensor_bytes = unsafe {
250 std::slice::from_raw_parts(
251 tensor_data.as_ptr() as *const u8,
252 tensor_data.len() * std::mem::size_of::<T>(),
253 )
254 };
255
256 let hooks_to_execute: Vec<(Uuid, HookConfig)> =
258 self.hooks.iter().map(|(id, config)| (*id, config.clone())).collect();
259
260 for (hook_id, hook_config) in hooks_to_execute {
261 if !hook_config.enabled {
262 continue;
263 }
264
265 if let Some(should_execute) = self.should_execute_hook(&hook_config, &context) {
267 if !should_execute {
268 results.push((
269 hook_id,
270 HookResult::Skipped("Condition not met".to_string()),
271 ));
272 continue;
273 }
274 }
275
276 let current_count = self.execution_count.get(&hook_id).copied().unwrap_or(0);
278 if let Some(max_executions) = hook_config.max_executions {
279 if current_count >= max_executions {
280 results.push((
281 hook_id,
282 HookResult::Skipped("Max executions reached".to_string()),
283 ));
284 continue;
285 }
286 }
287
288 let start_time = std::time::Instant::now();
290 let result = self.execute_single_hook(&hook_config, &context, tensor_bytes);
291 let execution_time = start_time.elapsed().as_millis() as f64;
292
293 if let Some(stats) = self.hook_stats.get_mut(&hook_id) {
295 stats.total_executions += 1;
296 stats.last_execution_step = Some(self.global_step);
297 stats.total_execution_time_ms += execution_time;
298 stats.avg_execution_time_ms =
299 stats.total_execution_time_ms / stats.total_executions as f64;
300
301 if matches!(result, HookResult::Error(_)) {
302 stats.errors += 1;
303 }
304 }
305
306 if let Some(count) = self.execution_count.get_mut(&hook_id) {
308 *count += 1;
309 }
310
311 results.push((hook_id, result));
312 }
313
314 results
315 }
316
317 pub fn get_hook(&self, hook_id: Uuid) -> Option<&HookConfig> {
319 self.hooks.get(&hook_id)
320 }
321
322 pub fn get_all_hooks(&self) -> Vec<&HookConfig> {
324 self.hooks.values().collect()
325 }
326
327 pub fn get_hook_stats(&self, hook_id: Uuid) -> Option<&HookStats> {
329 self.hook_stats.get(&hook_id)
330 }
331
332 pub fn get_all_stats(&self) -> Vec<&HookStats> {
334 self.hook_stats.values().collect()
335 }
336
337 pub fn clear_hooks(&mut self) {
339 self.hooks.clear();
340 self.hook_stats.clear();
341 self.execution_count.clear();
342 self.callbacks.clear();
343 }
344
345 pub fn create_tensor_inspection_hook(&mut self, layer_patterns: Vec<String>) -> Result<Uuid> {
347 let config = HookConfig {
348 id: Uuid::new_v4(),
349 name: "Tensor Inspector".to_string(),
350 trigger: HookTrigger::EveryForward,
351 actions: vec![HookAction::InspectTensor],
352 enabled: true,
353 max_executions: None,
354 layer_patterns,
355 };
356
357 self.register_hook(config)
358 }
359
360 pub fn create_gradient_tracking_hook(&mut self, layer_patterns: Vec<String>) -> Result<Uuid> {
362 let config = HookConfig {
363 id: Uuid::new_v4(),
364 name: "Gradient Tracker".to_string(),
365 trigger: HookTrigger::EveryBackward,
366 actions: vec![HookAction::TrackGradients],
367 enabled: true,
368 max_executions: None,
369 layer_patterns,
370 };
371
372 self.register_hook(config)
373 }
374
375 pub fn create_alert_hook(
377 &mut self,
378 condition: HookCondition,
379 message: String,
380 severity: AlertSeverity,
381 ) -> Result<Uuid> {
382 let config = HookConfig {
383 id: Uuid::new_v4(),
384 name: "Alert Hook".to_string(),
385 trigger: HookTrigger::Conditional(condition),
386 actions: vec![HookAction::Alert { message, severity }],
387 enabled: true,
388 max_executions: None,
389 layer_patterns: vec![".*".to_string()], };
391
392 self.register_hook(config)
393 }
394
395 fn should_execute_hook(&self, hook: &HookConfig, context: &HookContext) -> Option<bool> {
398 if !hook.layer_patterns.is_empty() {
400 let matches_pattern = hook.layer_patterns.iter().any(|pattern| {
401 regex::Regex::new(pattern)
402 .map(|re| re.is_match(&context.layer_name))
403 .unwrap_or(false)
404 });
405
406 if !matches_pattern {
407 return Some(false);
408 }
409 }
410
411 match &hook.trigger {
412 HookTrigger::EveryForward => Some(context.is_forward),
413 HookTrigger::EveryBackward => Some(!context.is_forward),
414 HookTrigger::EveryNSteps(n) => Some(context.step % n == 0),
415 HookTrigger::Conditional(condition) => {
416 Some(self.evaluate_condition(condition, context))
417 },
418 HookTrigger::Once => {
419 let count = self.execution_count.get(&hook.id).copied().unwrap_or(0);
420 Some(count == 0)
421 },
422 HookTrigger::LayerSpecific(layers) => Some(layers.contains(&context.layer_name)),
423 }
424 }
425
426 fn evaluate_condition(&self, condition: &HookCondition, context: &HookContext) -> bool {
427 match condition {
428 HookCondition::StepRange { start, end } => {
429 context.step >= *start && context.step <= *end
430 },
431 HookCondition::Custom(name) => {
432 context.metadata.contains_key(name)
435 },
436 _ => true,
438 }
439 }
440
441 fn execute_single_hook(
442 &mut self,
443 hook: &HookConfig,
444 context: &HookContext,
445 tensor_data: &[u8],
446 ) -> HookResult {
447 for action in &hook.actions {
448 match self.execute_action(action, context, tensor_data) {
449 Ok(()) => continue,
450 Err(e) => return HookResult::Error(e.to_string()),
451 }
452 }
453 HookResult::Success
454 }
455
456 fn execute_action(
457 &mut self,
458 action: &HookAction,
459 context: &HookContext,
460 tensor_data: &[u8],
461 ) -> Result<()> {
462 match action {
463 HookAction::InspectTensor => {
464 tracing::debug!(
465 "Inspecting tensor in layer '{}' at step {}",
466 context.layer_name,
467 context.step
468 );
469 Ok(())
471 },
472 HookAction::TrackGradients => {
473 tracing::debug!(
474 "Tracking gradients in layer '{}' at step {}",
475 context.layer_name,
476 context.step
477 );
478 Ok(())
480 },
481 HookAction::RecordActivations => {
482 tracing::debug!(
483 "Recording activations in layer '{}' at step {}",
484 context.layer_name,
485 context.step
486 );
487 Ok(())
489 },
490 HookAction::SaveSnapshot { path } => {
491 let file_path =
492 format!("{}_{}_step_{}.bin", path, context.layer_name, context.step);
493 std::fs::write(&file_path, tensor_data)?;
494 tracing::info!("Saved tensor snapshot to {}", file_path);
495 Ok(())
496 },
497 HookAction::Alert { message, severity } => {
498 match severity {
499 AlertSeverity::Info => tracing::info!("Hook Alert: {}", message),
500 AlertSeverity::Warning => tracing::warn!("Hook Alert: {}", message),
501 AlertSeverity::Critical => tracing::error!("Hook Alert: {}", message),
502 }
503 Ok(())
504 },
505 HookAction::CustomCallback { name } => {
506 if let Some(callback) = self.callbacks.get(name) {
507 callback(context, tensor_data)?;
508 } else {
509 return Err(anyhow::anyhow!("Callback '{}' not found", name));
510 }
511 Ok(())
512 },
513 HookAction::PauseTraining => {
514 tracing::warn!(
515 "Training paused by hook at step {} in layer '{}'",
516 context.step,
517 context.layer_name
518 );
519 Ok(())
521 },
522 }
523 }
524}
525
526impl Default for HookManager {
527 fn default() -> Self {
528 Self::new()
529 }
530}
531
532pub struct HookBuilder {
534 config: HookConfig,
535}
536
537impl HookBuilder {
538 pub fn new(name: &str) -> Self {
539 Self {
540 config: HookConfig {
541 id: Uuid::new_v4(),
542 name: name.to_string(),
543 trigger: HookTrigger::EveryForward,
544 actions: Vec::new(),
545 enabled: true,
546 max_executions: None,
547 layer_patterns: Vec::new(),
548 },
549 }
550 }
551
552 pub fn trigger(mut self, trigger: HookTrigger) -> Self {
553 self.config.trigger = trigger;
554 self
555 }
556
557 pub fn action(mut self, action: HookAction) -> Self {
558 self.config.actions.push(action);
559 self
560 }
561
562 pub fn actions(mut self, actions: Vec<HookAction>) -> Self {
563 self.config.actions = actions;
564 self
565 }
566
567 pub fn max_executions(mut self, max: usize) -> Self {
568 self.config.max_executions = Some(max);
569 self
570 }
571
572 pub fn layer_patterns(mut self, patterns: Vec<String>) -> Self {
573 self.config.layer_patterns = patterns;
574 self
575 }
576
577 pub fn enabled(mut self, enabled: bool) -> Self {
578 self.config.enabled = enabled;
579 self
580 }
581
582 pub fn build(self) -> HookConfig {
583 self.config
584 }
585}
586
/// Builds a [`HookConfig`] that runs `HookAction::InspectTensor` on every forward pass of
/// layers matching `$patterns` (a `Vec<String>` of regexes). A trailing comma is accepted.
///
/// NOTE(review): the expansion names `HookBuilder`, `HookTrigger` and `HookAction`
/// unqualified, so they must be in scope at the call site. A `$crate::…` path would remove
/// that requirement once this module's path is confirmed.
#[macro_export]
macro_rules! tensor_hook {
    ($name:expr, $patterns:expr $(,)?) => {
        HookBuilder::new($name)
            .trigger(HookTrigger::EveryForward)
            .action(HookAction::InspectTensor)
            .layer_patterns($patterns)
            .build()
    };
}
598
/// Builds a [`HookConfig`] that runs `HookAction::TrackGradients` on every backward pass of
/// layers matching `$patterns` (a `Vec<String>` of regexes). A trailing comma is accepted.
///
/// NOTE(review): the expansion names `HookBuilder`, `HookTrigger` and `HookAction`
/// unqualified, so they must be in scope at the call site. A `$crate::…` path would remove
/// that requirement once this module's path is confirmed.
#[macro_export]
macro_rules! gradient_hook {
    ($name:expr, $patterns:expr $(,)?) => {
        HookBuilder::new($name)
            .trigger(HookTrigger::EveryBackward)
            .action(HookAction::TrackGradients)
            .layer_patterns($patterns)
            .build()
    };
}
609
/// Builds an "Alert Hook" [`HookConfig`] that logs `$message` (anything with `to_string()`)
/// at `$severity` whenever `$condition` holds. A trailing comma is accepted.
///
/// NOTE(review): the expansion names `HookBuilder`, `HookTrigger` and `HookAction`
/// unqualified, so they must be in scope at the call site. A `$crate::…` path would remove
/// that requirement once this module's path is confirmed.
#[macro_export]
macro_rules! alert_hook {
    ($condition:expr, $message:expr, $severity:expr $(,)?) => {
        HookBuilder::new("Alert Hook")
            .trigger(HookTrigger::Conditional($condition))
            .action(HookAction::Alert {
                message: $message.to_string(),
                severity: $severity,
            })
            .build()
    };
}