1use std::sync::Arc;
10
11use serde_json::Value;
12use tokio::time::{Duration, timeout};
13
14use crate::BoxFuture;
15use crate::agents::error::AgentError;
16
/// Context handed to pre-tool-use hooks by `HookRegistry::run_pre_tool_use`.
///
/// Each matching hook receives its own clone of this value.
#[derive(Debug, Clone)]
pub struct PreToolUseContext {
    /// Name of the tool; compared against `HookMatcher::tool_name_pattern` to
    /// decide which hooks run.
    pub tool_name: String,
    /// Tool-call arguments as a JSON value (schema depends on the tool —
    /// presumably exactly what the tool will receive; confirm at call site).
    pub arguments: Value,
}

/// Context handed to post-tool-use hooks by `HookRegistry::run_post_tool_use`.
///
/// Each matching hook receives its own clone of this value.
#[derive(Debug, Clone)]
pub struct PostToolUseContext {
    /// Name of the tool; compared against `HookMatcher::tool_name_pattern` to
    /// decide which hooks run.
    pub tool_name: String,
    /// Tool-call arguments as a JSON value.
    pub arguments: Value,
    /// Tool output as a JSON value (presumably the successful result; failures
    /// appear to go through `PostToolUseFailureContext` instead — confirm).
    pub output: Value,
}

/// Context handed to failure hooks by `HookRegistry::run_post_tool_use_failure`.
///
/// Each matching hook receives its own clone of this value.
#[derive(Debug, Clone)]
pub struct PostToolUseFailureContext {
    /// Name of the tool; compared against `HookMatcher::tool_name_pattern` to
    /// decide which hooks run.
    pub tool_name: String,
    /// Tool-call arguments as a JSON value.
    pub arguments: Value,
    /// Failure description as plain text (format not fixed by this module).
    pub error: String,
}
51
/// Context handed to notification hooks by `HookRegistry::run_notification`.
///
/// Notification hooks ignore `HookMatcher::tool_name_pattern`; every
/// registered notification hook runs.
#[derive(Debug, Clone)]
pub struct NotificationContext {
    /// Notification text.
    pub message: String,
    /// Severity as a free-form string; no enum constrains the values here
    /// (presumably something like "info"/"warn"/"error" — confirm with emitters).
    pub level: String,
}

/// Context handed to subagent-start hooks by `HookRegistry::run_subagent_start`.
///
/// These hooks ignore `HookMatcher::tool_name_pattern`.
#[derive(Debug, Clone)]
pub struct SubagentStartContext {
    /// Name of the subagent.
    pub agent_name: String,
    /// Optional first message; `None` when there is none (presumably the
    /// prompt the subagent starts with — confirm at call site).
    pub initial_message: Option<String>,
}

/// Context handed to subagent-stop hooks by `HookRegistry::run_subagent_stop`.
///
/// These hooks ignore `HookMatcher::tool_name_pattern`.
#[derive(Debug, Clone)]
pub struct SubagentStopContext {
    /// Name of the subagent.
    pub agent_name: String,
    /// Free-form stop reason.
    pub reason: String,
}
78
/// Context handed to pre-compaction hooks by `HookRegistry::run_pre_compact`.
///
/// These hooks ignore `HookMatcher::tool_name_pattern`.
#[derive(Debug, Clone)]
pub struct PreCompactContext {
    /// Message count reported to the hook (presumably the conversation size
    /// before compaction — confirm at call site).
    pub message_count: usize,
    /// Token count reported to the hook (presumably measured before
    /// compaction — confirm at call site).
    pub token_count: u64,
}

/// Context handed to post-compaction hooks by `HookRegistry::run_post_compact`.
///
/// These hooks ignore `HookMatcher::tool_name_pattern`.
#[derive(Debug, Clone)]
pub struct PostCompactContext {
    /// Message count reported to the hook (presumably the size after
    /// compaction — confirm at call site).
    pub message_count: usize,
    /// Token count reported to the hook (presumably measured after
    /// compaction — confirm at call site).
    pub token_count: u64,
}

/// Context handed to session-start hooks by `HookRegistry::run_session_start`.
///
/// These hooks ignore `HookMatcher::tool_name_pattern`.
#[derive(Debug, Clone)]
pub struct SessionStartContext {
    /// Identifier of the session.
    pub session_id: String,
    /// Presumably `true` when an existing session is being resumed rather than
    /// newly created — confirm at call site.
    pub resumed: bool,
}

/// Context handed to session-end hooks by `HookRegistry::run_session_end`.
///
/// These hooks ignore `HookMatcher::tool_name_pattern`.
#[derive(Debug, Clone)]
pub struct SessionEndContext {
    /// Identifier of the session.
    pub session_id: String,
    /// Free-form end reason.
    pub reason: String,
}
114
/// Outcome returned by a hook callback.
///
/// The registry's `run_*` methods stop at the first `Abort` and return it;
/// otherwise they return `Continue`. A hook that exceeds its timeout is treated
/// as `Continue`.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare outcomes
/// directly (e.g. `assert_eq!(result, HookResult::Continue)`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum HookResult {
    /// Let the guarded operation proceed.
    Continue,
    /// Stop the guarded operation; the string carries the reason.
    Abort(String),
}
128
/// Decides which invocations a hook applies to and bounds how long it may run.
#[derive(Debug, Clone)]
pub struct HookMatcher {
    /// Glob pattern (only `*` is special) matched against the tool name; `None`
    /// matches every tool. Only the tool-use events consult this field.
    pub tool_name_pattern: Option<String>,
    /// Longest time a hook may run before it is abandoned and its outcome is
    /// treated as `HookResult::Continue`.
    pub timeout: Duration,
}

impl Default for HookMatcher {
    /// A matcher that accepts every tool and gives each hook 30 seconds.
    fn default() -> Self {
        let timeout = Duration::from_secs(30);
        Self {
            timeout,
            tool_name_pattern: None,
        }
    }
}
151
152impl HookMatcher {
153 #[must_use]
155 pub fn matches_tool(&self, tool_name: &str) -> bool {
156 self.tool_name_pattern
157 .as_ref()
158 .is_none_or(|pattern| glob_match(pattern, tool_name))
159 }
160}
161
/// Matches `input` against `pattern`, where `*` matches any run of characters
/// (including none) and every other character matches literally. There is no
/// `?`, character-class, or escape syntax.
///
/// The text before the first `*` must be a prefix of `input`. The text after
/// the last `*` must be a suffix of `input`. The segments in between must
/// appear in order between them without overlapping. Anchoring the final
/// segment at the end, rather than taking its first occurrence, is what lets
/// `*_file` match `read_file_file`.
fn glob_match(pattern: &str, input: &str) -> bool {
    // No wildcard at all: plain equality.
    let Some((prefix, after_prefix)) = pattern.split_once('*') else {
        return pattern == input;
    };
    // Text after the last `*` must end the input. Text between the first and
    // last `*` (possibly empty) holds the floating middle segments.
    let (middle, suffix) = after_prefix
        .rsplit_once('*')
        .unwrap_or(("", after_prefix));

    let Some(mut remaining) = input.strip_prefix(prefix) else {
        return false;
    };
    // Placing each middle segment as early as possible leaves the longest
    // tail for later segments and the suffix, so a greedy scan is sufficient.
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match remaining.split_once(segment) {
            Some((_, rest)) => remaining = rest,
            None => return false,
        }
    }
    // Checking within `remaining` keeps the suffix from overlapping text that
    // the prefix or a middle segment already consumed.
    remaining.ends_with(suffix)
}
192
193type PreToolUseFn = Arc<dyn Fn(PreToolUseContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
198type PostToolUseFn =
199 Arc<dyn Fn(PostToolUseContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
200type PostToolUseFailureFn =
201 Arc<dyn Fn(PostToolUseFailureContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
202type NotificationFn =
203 Arc<dyn Fn(NotificationContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
204type SubagentStartFn =
205 Arc<dyn Fn(SubagentStartContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
206type SubagentStopFn =
207 Arc<dyn Fn(SubagentStopContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
208type PreCompactFn = Arc<dyn Fn(PreCompactContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
209type PostCompactFn =
210 Arc<dyn Fn(PostCompactContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
211type SessionStartFn =
212 Arc<dyn Fn(SessionStartContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
213type SessionEndFn = Arc<dyn Fn(SessionEndContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
214
/// A registered hook: the event it listens for (the variant), its matcher, and
/// its callback.
///
/// `Debug` is implemented by hand because the callback trait objects have no
/// `Debug` impl.
enum HookEntry {
    PreToolUse(HookMatcher, PreToolUseFn),
    PostToolUse(HookMatcher, PostToolUseFn),
    PostToolUseFailure(HookMatcher, PostToolUseFailureFn),
    Notification(HookMatcher, NotificationFn),
    SubagentStart(HookMatcher, SubagentStartFn),
    SubagentStop(HookMatcher, SubagentStopFn),
    PreCompact(HookMatcher, PreCompactFn),
    PostCompact(HookMatcher, PostCompactFn),
    SessionStart(HookMatcher, SessionStartFn),
    SessionEnd(HookMatcher, SessionEndFn),
}
227
228impl std::fmt::Debug for HookEntry {
229 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230 match self {
231 Self::PreToolUse(m, _) => write!(f, "PreToolUse({m:?})"),
232 Self::PostToolUse(m, _) => write!(f, "PostToolUse({m:?})"),
233 Self::PostToolUseFailure(m, _) => write!(f, "PostToolUseFailure({m:?})"),
234 Self::Notification(m, _) => write!(f, "Notification({m:?})"),
235 Self::SubagentStart(m, _) => write!(f, "SubagentStart({m:?})"),
236 Self::SubagentStop(m, _) => write!(f, "SubagentStop({m:?})"),
237 Self::PreCompact(m, _) => write!(f, "PreCompact({m:?})"),
238 Self::PostCompact(m, _) => write!(f, "PostCompact({m:?})"),
239 Self::SessionStart(m, _) => write!(f, "SessionStart({m:?})"),
240 Self::SessionEnd(m, _) => write!(f, "SessionEnd({m:?})"),
241 }
242 }
243}
244
/// Ordered collection of registered hooks.
///
/// Hooks for every event kind share one `Vec`. Each `run_*` method scans it in
/// registration order and invokes only the entries of its own kind.
#[derive(Debug, Default)]
pub struct HookRegistry {
    hooks: Vec<HookEntry>,
}
254
255impl HookRegistry {
256 #[must_use]
258 pub fn new() -> Self {
259 Self::default()
260 }
261
262 pub fn on_pre_tool_use<F>(&mut self, matcher: HookMatcher, f: F)
266 where
267 F: Fn(PreToolUseContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
268 {
269 self.hooks.push(HookEntry::PreToolUse(matcher, Arc::new(f)));
270 }
271
272 pub fn on_post_tool_use<F>(&mut self, matcher: HookMatcher, f: F)
274 where
275 F: Fn(PostToolUseContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
276 {
277 self.hooks
278 .push(HookEntry::PostToolUse(matcher, Arc::new(f)));
279 }
280
281 pub fn on_post_tool_use_failure<F>(&mut self, matcher: HookMatcher, f: F)
283 where
284 F: Fn(PostToolUseFailureContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
285 {
286 self.hooks
287 .push(HookEntry::PostToolUseFailure(matcher, Arc::new(f)));
288 }
289
290 pub fn on_notification<F>(&mut self, matcher: HookMatcher, f: F)
292 where
293 F: Fn(NotificationContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
294 {
295 self.hooks
296 .push(HookEntry::Notification(matcher, Arc::new(f)));
297 }
298
299 pub fn on_subagent_start<F>(&mut self, matcher: HookMatcher, f: F)
301 where
302 F: Fn(SubagentStartContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
303 {
304 self.hooks
305 .push(HookEntry::SubagentStart(matcher, Arc::new(f)));
306 }
307
308 pub fn on_subagent_stop<F>(&mut self, matcher: HookMatcher, f: F)
310 where
311 F: Fn(SubagentStopContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
312 {
313 self.hooks
314 .push(HookEntry::SubagentStop(matcher, Arc::new(f)));
315 }
316
317 pub fn on_pre_compact<F>(&mut self, matcher: HookMatcher, f: F)
319 where
320 F: Fn(PreCompactContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
321 {
322 self.hooks.push(HookEntry::PreCompact(matcher, Arc::new(f)));
323 }
324
325 pub fn on_post_compact<F>(&mut self, matcher: HookMatcher, f: F)
327 where
328 F: Fn(PostCompactContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
329 {
330 self.hooks
331 .push(HookEntry::PostCompact(matcher, Arc::new(f)));
332 }
333
334 pub fn on_session_start<F>(&mut self, matcher: HookMatcher, f: F)
336 where
337 F: Fn(SessionStartContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
338 {
339 self.hooks
340 .push(HookEntry::SessionStart(matcher, Arc::new(f)));
341 }
342
343 pub fn on_session_end<F>(&mut self, matcher: HookMatcher, f: F)
345 where
346 F: Fn(SessionEndContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
347 {
348 self.hooks.push(HookEntry::SessionEnd(matcher, Arc::new(f)));
349 }
350
351 pub async fn run_pre_tool_use(&self, ctx: PreToolUseContext) -> Result<HookResult, AgentError> {
358 for entry in &self.hooks {
359 if let HookEntry::PreToolUse(matcher, f) = entry {
360 if !matcher.matches_tool(&ctx.tool_name) {
361 continue;
362 }
363 if let HookResult::Abort(msg) =
364 run_with_timeout(f(ctx.clone()), matcher.timeout).await
365 {
366 return Ok(HookResult::Abort(msg));
367 }
368 }
369 }
370 Ok(HookResult::Continue)
371 }
372
373 pub async fn run_post_tool_use(
375 &self,
376 ctx: PostToolUseContext,
377 ) -> Result<HookResult, AgentError> {
378 for entry in &self.hooks {
379 if let HookEntry::PostToolUse(matcher, f) = entry {
380 if !matcher.matches_tool(&ctx.tool_name) {
381 continue;
382 }
383 if let HookResult::Abort(msg) =
384 run_with_timeout(f(ctx.clone()), matcher.timeout).await
385 {
386 return Ok(HookResult::Abort(msg));
387 }
388 }
389 }
390 Ok(HookResult::Continue)
391 }
392
393 pub async fn run_post_tool_use_failure(
395 &self,
396 ctx: PostToolUseFailureContext,
397 ) -> Result<HookResult, AgentError> {
398 for entry in &self.hooks {
399 if let HookEntry::PostToolUseFailure(matcher, f) = entry {
400 if !matcher.matches_tool(&ctx.tool_name) {
401 continue;
402 }
403 if let HookResult::Abort(msg) =
404 run_with_timeout(f(ctx.clone()), matcher.timeout).await
405 {
406 return Ok(HookResult::Abort(msg));
407 }
408 }
409 }
410 Ok(HookResult::Continue)
411 }
412
413 pub async fn run_notification(
415 &self,
416 ctx: NotificationContext,
417 ) -> Result<HookResult, AgentError> {
418 for entry in &self.hooks {
419 if let HookEntry::Notification(matcher, f) = entry
420 && let HookResult::Abort(msg) =
421 run_with_timeout(f(ctx.clone()), matcher.timeout).await
422 {
423 return Ok(HookResult::Abort(msg));
424 }
425 }
426 Ok(HookResult::Continue)
427 }
428
429 pub async fn run_session_start(
431 &self,
432 ctx: SessionStartContext,
433 ) -> Result<HookResult, AgentError> {
434 for entry in &self.hooks {
435 if let HookEntry::SessionStart(matcher, f) = entry
436 && let HookResult::Abort(msg) =
437 run_with_timeout(f(ctx.clone()), matcher.timeout).await
438 {
439 return Ok(HookResult::Abort(msg));
440 }
441 }
442 Ok(HookResult::Continue)
443 }
444
445 pub async fn run_session_end(&self, ctx: SessionEndContext) -> Result<HookResult, AgentError> {
447 for entry in &self.hooks {
448 if let HookEntry::SessionEnd(matcher, f) = entry
449 && let HookResult::Abort(msg) =
450 run_with_timeout(f(ctx.clone()), matcher.timeout).await
451 {
452 return Ok(HookResult::Abort(msg));
453 }
454 }
455 Ok(HookResult::Continue)
456 }
457
458 pub async fn run_subagent_start(
460 &self,
461 ctx: SubagentStartContext,
462 ) -> Result<HookResult, AgentError> {
463 for entry in &self.hooks {
464 if let HookEntry::SubagentStart(matcher, f) = entry
465 && let HookResult::Abort(msg) =
466 run_with_timeout(f(ctx.clone()), matcher.timeout).await
467 {
468 return Ok(HookResult::Abort(msg));
469 }
470 }
471 Ok(HookResult::Continue)
472 }
473
474 pub async fn run_subagent_stop(
476 &self,
477 ctx: SubagentStopContext,
478 ) -> Result<HookResult, AgentError> {
479 for entry in &self.hooks {
480 if let HookEntry::SubagentStop(matcher, f) = entry
481 && let HookResult::Abort(msg) =
482 run_with_timeout(f(ctx.clone()), matcher.timeout).await
483 {
484 return Ok(HookResult::Abort(msg));
485 }
486 }
487 Ok(HookResult::Continue)
488 }
489
490 pub async fn run_pre_compact(&self, ctx: PreCompactContext) -> Result<HookResult, AgentError> {
492 for entry in &self.hooks {
493 if let HookEntry::PreCompact(matcher, f) = entry
494 && let HookResult::Abort(msg) =
495 run_with_timeout(f(ctx.clone()), matcher.timeout).await
496 {
497 return Ok(HookResult::Abort(msg));
498 }
499 }
500 Ok(HookResult::Continue)
501 }
502
503 pub async fn run_post_compact(
505 &self,
506 ctx: PostCompactContext,
507 ) -> Result<HookResult, AgentError> {
508 for entry in &self.hooks {
509 if let HookEntry::PostCompact(matcher, f) = entry
510 && let HookResult::Abort(msg) =
511 run_with_timeout(f(ctx.clone()), matcher.timeout).await
512 {
513 return Ok(HookResult::Abort(msg));
514 }
515 }
516 Ok(HookResult::Continue)
517 }
518}
519
520async fn run_with_timeout(fut: BoxFuture<'static, HookResult>, duration: Duration) -> HookResult {
522 match timeout(duration, fut).await {
523 Ok(result) => result,
524 Err(_elapsed) => {
525 tracing::warn!(?duration, "Hook timed out — skipping");
526 HookResult::Continue
527 }
528 }
529}
530
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    // A catch-all matcher (the default has no tool pattern) plus an aborting
    // hook must surface the abort to the caller.
    #[tokio::test]
    async fn test_pre_tool_use_abort() {
        let mut registry = HookRegistry::new();
        registry.on_pre_tool_use(HookMatcher::default(), |_ctx| {
            Box::pin(async { HookResult::Abort("blocked".to_string()) })
        });
        let ctx = PreToolUseContext {
            tool_name: "read_file".to_string(),
            arguments: serde_json::json!({}),
        };
        let result = registry.run_pre_tool_use(ctx).await.unwrap();
        assert!(matches!(result, HookResult::Abort(_)));
    }

    // The hook is registered for `write_*` only, so a `read_file` call must
    // skip it and its abort must never be seen.
    #[tokio::test]
    async fn test_tool_name_pattern_no_match() {
        let mut registry = HookRegistry::new();
        registry.on_pre_tool_use(
            HookMatcher {
                tool_name_pattern: Some("write_*".to_string()),
                timeout: Duration::from_secs(5),
            },
            |_ctx| Box::pin(async { HookResult::Abort("blocked".to_string()) }),
        );
        let ctx = PreToolUseContext {
            tool_name: "read_file".to_string(),
            arguments: serde_json::json!({}),
        };
        let result = registry.run_pre_tool_use(ctx).await.unwrap();
        assert!(matches!(result, HookResult::Continue));
    }

    // The hook sleeps far longer than its 10 ms timeout, so it is abandoned
    // before it can return its abort.
    #[tokio::test]
    async fn test_timeout_skips_hook() {
        let mut registry = HookRegistry::new();
        registry.on_pre_tool_use(
            HookMatcher {
                tool_name_pattern: None,
                timeout: Duration::from_millis(10),
            },
            |_ctx| {
                Box::pin(async {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                    HookResult::Abort("late abort".to_string())
                })
            },
        );
        let ctx = PreToolUseContext {
            tool_name: "read_file".to_string(),
            arguments: serde_json::json!({}),
        };
        let result = registry.run_pre_tool_use(ctx).await.unwrap();
        // A timed-out hook is treated as Continue; its late abort is ignored.
        assert!(matches!(result, HookResult::Continue));
    }

    // Direct checks of the `*`-only glob used by `HookMatcher::matches_tool`.
    #[test]
    fn test_glob_match() {
        assert!(glob_match("*", "anything"));
        assert!(glob_match("write_*", "write_file"));
        assert!(!glob_match("write_*", "read_file"));
        assert!(glob_match("*_file", "read_file"));
        assert!(glob_match("exact", "exact"));
        assert!(!glob_match("exact", "not_exact"));
    }
}