Skip to main content

synwire_core/agents/
hooks.rs

1//! Lifecycle hooks for agent execution.
2//!
//! Hooks are short-lived async callbacks invoked at well-defined points in the
//! agent lifecycle.  Each hook type carries its own payload.  Tool-use hooks
//! (pre-tool-use, post-tool-use, and post-tool-use failure) are filtered by an
//! optional tool-name glob pattern; the other hook types ignore the pattern.
//! Every hook runs under an enforced timeout; hooks that exceed it are skipped
//! with a warning rather than failing the agent.
8
9use std::sync::Arc;
10
11use serde_json::Value;
12use tokio::time::{Duration, timeout};
13
14use crate::BoxFuture;
15use crate::agents::error::AgentError;
16
17// ---------------------------------------------------------------------------
18// Payload types
19// ---------------------------------------------------------------------------
20
21/// Context passed to pre-tool-use hooks.
22#[derive(Debug, Clone)]
23pub struct PreToolUseContext {
24    /// Tool name.
25    pub tool_name: String,
26    /// Tool arguments.
27    pub arguments: Value,
28}
29
30/// Context passed to post-tool-use (success) hooks.
31#[derive(Debug, Clone)]
32pub struct PostToolUseContext {
33    /// Tool name.
34    pub tool_name: String,
35    /// Tool arguments.
36    pub arguments: Value,
37    /// Tool output.
38    pub output: Value,
39}
40
41/// Context passed to post-tool-use (failure) hooks.
42#[derive(Debug, Clone)]
43pub struct PostToolUseFailureContext {
44    /// Tool name.
45    pub tool_name: String,
46    /// Tool arguments.
47    pub arguments: Value,
48    /// Error message.
49    pub error: String,
50}
51
/// Payload handed to notification hooks.
#[derive(Clone, Debug)]
pub struct NotificationContext {
    /// Text of the notification.
    pub message: String,
    /// Severity label attached to the notification.
    pub level: String,
}
60
/// Payload handed to subagent start hooks.
#[derive(Clone, Debug)]
pub struct SubagentStartContext {
    /// Name of the subagent being started.
    pub agent_name: String,
    /// First message for the subagent, if one was supplied.
    pub initial_message: Option<String>,
}
69
/// Payload handed to subagent stop hooks.
#[derive(Clone, Debug)]
pub struct SubagentStopContext {
    /// Name of the subagent being stopped.
    pub agent_name: String,
    /// Why the subagent terminated.
    pub reason: String,
}
78
/// Payload handed to pre-compact hooks.
#[derive(Clone, Debug)]
pub struct PreCompactContext {
    /// Number of messages before compaction.
    pub message_count: usize,
    /// Estimated token count before compaction.
    pub token_count: u64,
}
87
/// Payload handed to post-compact hooks.
#[derive(Clone, Debug)]
pub struct PostCompactContext {
    /// Number of messages remaining after compaction.
    pub message_count: usize,
    /// Estimated token count after compaction.
    pub token_count: u64,
}
96
/// Payload handed to session start hooks.
#[derive(Clone, Debug)]
pub struct SessionStartContext {
    /// Identifier of the session.
    pub session_id: String,
    /// `true` when an existing session is being resumed.
    pub resumed: bool,
}
105
/// Payload handed to session end hooks.
#[derive(Clone, Debug)]
pub struct SessionEndContext {
    /// Identifier of the session.
    pub session_id: String,
    /// Why the session terminated.
    pub reason: String,
}
114
115// ---------------------------------------------------------------------------
116// Hook result
117// ---------------------------------------------------------------------------
118
/// Result returned by a hook invocation.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare outcomes directly
/// (e.g. `assert_eq!(result, HookResult::Continue)`) instead of resorting to
/// `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum HookResult {
    /// Continue normal execution.
    Continue,
    /// Abort the current operation with a message.
    Abort(String),
}
128
129// ---------------------------------------------------------------------------
130// Hook matcher
131// ---------------------------------------------------------------------------
132
/// Selection criteria and time budget attached to one registered hook.
#[derive(Clone, Debug)]
pub struct HookMatcher {
    /// Glob over the tool name in which `*` stands for any run of characters,
    /// e.g. `"read_file"` or `"*_file"`. `None` accepts every tool.
    pub tool_name_pattern: Option<String>,
    /// Longest time a single invocation of the hook may take before it is
    /// skipped.
    pub timeout: Duration,
}
142
143impl Default for HookMatcher {
144    fn default() -> Self {
145        Self {
146            tool_name_pattern: None,
147            timeout: Duration::from_secs(30),
148        }
149    }
150}
151
152impl HookMatcher {
153    /// Returns `true` if the matcher applies to the given tool name.
154    #[must_use]
155    pub fn matches_tool(&self, tool_name: &str) -> bool {
156        self.tool_name_pattern
157            .as_ref()
158            .is_none_or(|pattern| glob_match(pattern, tool_name))
159    }
160}
161
/// Simple glob: `*` matches any (possibly empty) sequence of characters and
/// every other character matches itself literally.
///
/// A pattern without `*` must equal `input` exactly. Otherwise:
/// - the literal text before the first `*` must be a prefix of `input`;
/// - the literal text after the last `*` must be a suffix of what remains, so
///   it cannot overlap the prefix;
/// - each literal segment between stars must appear, in order, in the part of
///   the input left between that prefix and suffix.
fn glob_match(pattern: &str, input: &str) -> bool {
    let segments: Vec<&str> = pattern.split('*').collect();

    // Fewer than two segments means the pattern contains no `*` at all.
    let [first, middle @ .., last] = segments.as_slice() else {
        return pattern == input;
    };

    let Some(rest) = input.strip_prefix(first) else {
        return false;
    };
    // Anchor the final segment at the end *before* scanning the middle ones.
    // Searching for its first occurrence instead (the previous behaviour)
    // wrongly rejected inputs such as `read_file_file` for `*_file`.
    let Some(mut rest) = rest.strip_suffix(last) else {
        return false;
    };

    // Leftmost matching of each middle segment is optimal when `*` is the only
    // wildcard: it leaves the most room for the segments that follow.
    for segment in middle {
        match rest.find(segment) {
            // `find` returns a char boundary, and `segment` is valid UTF-8, so
            // the new start is also a char boundary.
            Some(pos) => rest = &rest[pos + segment.len()..],
            None => return false,
        }
    }
    true
}
192
193// ---------------------------------------------------------------------------
194// Type-erased hook function wrappers
195// ---------------------------------------------------------------------------
196
197type PreToolUseFn = Arc<dyn Fn(PreToolUseContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
198type PostToolUseFn =
199    Arc<dyn Fn(PostToolUseContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
200type PostToolUseFailureFn =
201    Arc<dyn Fn(PostToolUseFailureContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
202type NotificationFn =
203    Arc<dyn Fn(NotificationContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
204type SubagentStartFn =
205    Arc<dyn Fn(SubagentStartContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
206type SubagentStopFn =
207    Arc<dyn Fn(SubagentStopContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
208type PreCompactFn = Arc<dyn Fn(PreCompactContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
209type PostCompactFn =
210    Arc<dyn Fn(PostCompactContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
211type SessionStartFn =
212    Arc<dyn Fn(SessionStartContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
213type SessionEndFn = Arc<dyn Fn(SessionEndContext) -> BoxFuture<'static, HookResult> + Send + Sync>;
214
215enum HookEntry {
216    PreToolUse(HookMatcher, PreToolUseFn),
217    PostToolUse(HookMatcher, PostToolUseFn),
218    PostToolUseFailure(HookMatcher, PostToolUseFailureFn),
219    Notification(HookMatcher, NotificationFn),
220    SubagentStart(HookMatcher, SubagentStartFn),
221    SubagentStop(HookMatcher, SubagentStopFn),
222    PreCompact(HookMatcher, PreCompactFn),
223    PostCompact(HookMatcher, PostCompactFn),
224    SessionStart(HookMatcher, SessionStartFn),
225    SessionEnd(HookMatcher, SessionEndFn),
226}
227
228impl std::fmt::Debug for HookEntry {
229    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230        match self {
231            Self::PreToolUse(m, _) => write!(f, "PreToolUse({m:?})"),
232            Self::PostToolUse(m, _) => write!(f, "PostToolUse({m:?})"),
233            Self::PostToolUseFailure(m, _) => write!(f, "PostToolUseFailure({m:?})"),
234            Self::Notification(m, _) => write!(f, "Notification({m:?})"),
235            Self::SubagentStart(m, _) => write!(f, "SubagentStart({m:?})"),
236            Self::SubagentStop(m, _) => write!(f, "SubagentStop({m:?})"),
237            Self::PreCompact(m, _) => write!(f, "PreCompact({m:?})"),
238            Self::PostCompact(m, _) => write!(f, "PostCompact({m:?})"),
239            Self::SessionStart(m, _) => write!(f, "SessionStart({m:?})"),
240            Self::SessionEnd(m, _) => write!(f, "SessionEnd({m:?})"),
241        }
242    }
243}
244
245// ---------------------------------------------------------------------------
246// HookRegistry
247// ---------------------------------------------------------------------------
248
249/// Registry of lifecycle hooks with typed registration and timeout enforcement.
250#[derive(Debug, Default)]
251pub struct HookRegistry {
252    hooks: Vec<HookEntry>,
253}
254
255impl HookRegistry {
256    /// Create an empty registry.
257    #[must_use]
258    pub fn new() -> Self {
259        Self::default()
260    }
261
262    // --- Registration ---
263
264    /// Register a pre-tool-use hook.
265    pub fn on_pre_tool_use<F>(&mut self, matcher: HookMatcher, f: F)
266    where
267        F: Fn(PreToolUseContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
268    {
269        self.hooks.push(HookEntry::PreToolUse(matcher, Arc::new(f)));
270    }
271
272    /// Register a post-tool-use hook.
273    pub fn on_post_tool_use<F>(&mut self, matcher: HookMatcher, f: F)
274    where
275        F: Fn(PostToolUseContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
276    {
277        self.hooks
278            .push(HookEntry::PostToolUse(matcher, Arc::new(f)));
279    }
280
281    /// Register a post-tool-use failure hook.
282    pub fn on_post_tool_use_failure<F>(&mut self, matcher: HookMatcher, f: F)
283    where
284        F: Fn(PostToolUseFailureContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
285    {
286        self.hooks
287            .push(HookEntry::PostToolUseFailure(matcher, Arc::new(f)));
288    }
289
290    /// Register a notification hook.
291    pub fn on_notification<F>(&mut self, matcher: HookMatcher, f: F)
292    where
293        F: Fn(NotificationContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
294    {
295        self.hooks
296            .push(HookEntry::Notification(matcher, Arc::new(f)));
297    }
298
299    /// Register a subagent start hook.
300    pub fn on_subagent_start<F>(&mut self, matcher: HookMatcher, f: F)
301    where
302        F: Fn(SubagentStartContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
303    {
304        self.hooks
305            .push(HookEntry::SubagentStart(matcher, Arc::new(f)));
306    }
307
308    /// Register a subagent stop hook.
309    pub fn on_subagent_stop<F>(&mut self, matcher: HookMatcher, f: F)
310    where
311        F: Fn(SubagentStopContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
312    {
313        self.hooks
314            .push(HookEntry::SubagentStop(matcher, Arc::new(f)));
315    }
316
317    /// Register a pre-compact hook.
318    pub fn on_pre_compact<F>(&mut self, matcher: HookMatcher, f: F)
319    where
320        F: Fn(PreCompactContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
321    {
322        self.hooks.push(HookEntry::PreCompact(matcher, Arc::new(f)));
323    }
324
325    /// Register a post-compact hook.
326    pub fn on_post_compact<F>(&mut self, matcher: HookMatcher, f: F)
327    where
328        F: Fn(PostCompactContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
329    {
330        self.hooks
331            .push(HookEntry::PostCompact(matcher, Arc::new(f)));
332    }
333
334    /// Register a session start hook.
335    pub fn on_session_start<F>(&mut self, matcher: HookMatcher, f: F)
336    where
337        F: Fn(SessionStartContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
338    {
339        self.hooks
340            .push(HookEntry::SessionStart(matcher, Arc::new(f)));
341    }
342
343    /// Register a session end hook.
344    pub fn on_session_end<F>(&mut self, matcher: HookMatcher, f: F)
345    where
346        F: Fn(SessionEndContext) -> BoxFuture<'static, HookResult> + Send + Sync + 'static,
347    {
348        self.hooks.push(HookEntry::SessionEnd(matcher, Arc::new(f)));
349    }
350
351    // --- Invocation ---
352
353    /// Run all matching pre-tool-use hooks in registration order.
354    ///
355    /// Returns the first `Abort` result encountered, or `Continue` if all pass.
356    /// Hooks that exceed their timeout are skipped with a `warn!` log.
357    pub async fn run_pre_tool_use(&self, ctx: PreToolUseContext) -> Result<HookResult, AgentError> {
358        for entry in &self.hooks {
359            if let HookEntry::PreToolUse(matcher, f) = entry {
360                if !matcher.matches_tool(&ctx.tool_name) {
361                    continue;
362                }
363                if let HookResult::Abort(msg) =
364                    run_with_timeout(f(ctx.clone()), matcher.timeout).await
365                {
366                    return Ok(HookResult::Abort(msg));
367                }
368            }
369        }
370        Ok(HookResult::Continue)
371    }
372
373    /// Run all matching post-tool-use hooks.
374    pub async fn run_post_tool_use(
375        &self,
376        ctx: PostToolUseContext,
377    ) -> Result<HookResult, AgentError> {
378        for entry in &self.hooks {
379            if let HookEntry::PostToolUse(matcher, f) = entry {
380                if !matcher.matches_tool(&ctx.tool_name) {
381                    continue;
382                }
383                if let HookResult::Abort(msg) =
384                    run_with_timeout(f(ctx.clone()), matcher.timeout).await
385                {
386                    return Ok(HookResult::Abort(msg));
387                }
388            }
389        }
390        Ok(HookResult::Continue)
391    }
392
393    /// Run all matching post-tool-use failure hooks.
394    pub async fn run_post_tool_use_failure(
395        &self,
396        ctx: PostToolUseFailureContext,
397    ) -> Result<HookResult, AgentError> {
398        for entry in &self.hooks {
399            if let HookEntry::PostToolUseFailure(matcher, f) = entry {
400                if !matcher.matches_tool(&ctx.tool_name) {
401                    continue;
402                }
403                if let HookResult::Abort(msg) =
404                    run_with_timeout(f(ctx.clone()), matcher.timeout).await
405                {
406                    return Ok(HookResult::Abort(msg));
407                }
408            }
409        }
410        Ok(HookResult::Continue)
411    }
412
413    /// Run all notification hooks.
414    pub async fn run_notification(
415        &self,
416        ctx: NotificationContext,
417    ) -> Result<HookResult, AgentError> {
418        for entry in &self.hooks {
419            if let HookEntry::Notification(matcher, f) = entry
420                && let HookResult::Abort(msg) =
421                    run_with_timeout(f(ctx.clone()), matcher.timeout).await
422            {
423                return Ok(HookResult::Abort(msg));
424            }
425        }
426        Ok(HookResult::Continue)
427    }
428
429    /// Run all session start hooks.
430    pub async fn run_session_start(
431        &self,
432        ctx: SessionStartContext,
433    ) -> Result<HookResult, AgentError> {
434        for entry in &self.hooks {
435            if let HookEntry::SessionStart(matcher, f) = entry
436                && let HookResult::Abort(msg) =
437                    run_with_timeout(f(ctx.clone()), matcher.timeout).await
438            {
439                return Ok(HookResult::Abort(msg));
440            }
441        }
442        Ok(HookResult::Continue)
443    }
444
445    /// Run all session end hooks.
446    pub async fn run_session_end(&self, ctx: SessionEndContext) -> Result<HookResult, AgentError> {
447        for entry in &self.hooks {
448            if let HookEntry::SessionEnd(matcher, f) = entry
449                && let HookResult::Abort(msg) =
450                    run_with_timeout(f(ctx.clone()), matcher.timeout).await
451            {
452                return Ok(HookResult::Abort(msg));
453            }
454        }
455        Ok(HookResult::Continue)
456    }
457
458    /// Run all subagent start hooks.
459    pub async fn run_subagent_start(
460        &self,
461        ctx: SubagentStartContext,
462    ) -> Result<HookResult, AgentError> {
463        for entry in &self.hooks {
464            if let HookEntry::SubagentStart(matcher, f) = entry
465                && let HookResult::Abort(msg) =
466                    run_with_timeout(f(ctx.clone()), matcher.timeout).await
467            {
468                return Ok(HookResult::Abort(msg));
469            }
470        }
471        Ok(HookResult::Continue)
472    }
473
474    /// Run all subagent stop hooks.
475    pub async fn run_subagent_stop(
476        &self,
477        ctx: SubagentStopContext,
478    ) -> Result<HookResult, AgentError> {
479        for entry in &self.hooks {
480            if let HookEntry::SubagentStop(matcher, f) = entry
481                && let HookResult::Abort(msg) =
482                    run_with_timeout(f(ctx.clone()), matcher.timeout).await
483            {
484                return Ok(HookResult::Abort(msg));
485            }
486        }
487        Ok(HookResult::Continue)
488    }
489
490    /// Run all pre-compact hooks.
491    pub async fn run_pre_compact(&self, ctx: PreCompactContext) -> Result<HookResult, AgentError> {
492        for entry in &self.hooks {
493            if let HookEntry::PreCompact(matcher, f) = entry
494                && let HookResult::Abort(msg) =
495                    run_with_timeout(f(ctx.clone()), matcher.timeout).await
496            {
497                return Ok(HookResult::Abort(msg));
498            }
499        }
500        Ok(HookResult::Continue)
501    }
502
503    /// Run all post-compact hooks.
504    pub async fn run_post_compact(
505        &self,
506        ctx: PostCompactContext,
507    ) -> Result<HookResult, AgentError> {
508        for entry in &self.hooks {
509            if let HookEntry::PostCompact(matcher, f) = entry
510                && let HookResult::Abort(msg) =
511                    run_with_timeout(f(ctx.clone()), matcher.timeout).await
512            {
513                return Ok(HookResult::Abort(msg));
514            }
515        }
516        Ok(HookResult::Continue)
517    }
518}
519
520/// Run a hook future with a timeout; return `Continue` on timeout with a warning.
521async fn run_with_timeout(fut: BoxFuture<'static, HookResult>, duration: Duration) -> HookResult {
522    match timeout(duration, fut).await {
523        Ok(result) => result,
524        Err(_elapsed) => {
525            tracing::warn!(?duration, "Hook timed out — skipping");
526            HookResult::Continue
527        }
528    }
529}
530
#[cfg(test)]
// A failed unwrap/expect/panic is how a test reports a broken expectation, so
// the crate-wide lints against them are relaxed for this module only.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Pre-tool-use payload for a `read_file` call with empty arguments.
    fn read_file_ctx() -> PreToolUseContext {
        PreToolUseContext {
            tool_name: "read_file".to_string(),
            arguments: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn test_pre_tool_use_abort() {
        let mut registry = HookRegistry::new();
        registry.on_pre_tool_use(HookMatcher::default(), |_ctx| {
            Box::pin(async { HookResult::Abort("blocked".to_string()) })
        });

        let outcome = registry.run_pre_tool_use(read_file_ctx()).await.unwrap();
        assert!(matches!(outcome, HookResult::Abort(_)));
    }

    #[tokio::test]
    async fn test_tool_name_pattern_no_match() {
        let write_only = HookMatcher {
            tool_name_pattern: Some("write_*".to_string()),
            timeout: Duration::from_secs(5),
        };
        let mut registry = HookRegistry::new();
        registry.on_pre_tool_use(write_only, |_ctx| {
            Box::pin(async { HookResult::Abort("blocked".to_string()) })
        });

        // `read_file` does not match `write_*`, so the aborting hook never runs.
        let outcome = registry.run_pre_tool_use(read_file_ctx()).await.unwrap();
        assert!(matches!(outcome, HookResult::Continue));
    }

    #[tokio::test]
    async fn test_timeout_skips_hook() {
        let impatient = HookMatcher {
            tool_name_pattern: None,
            timeout: Duration::from_millis(10),
        };
        let mut registry = HookRegistry::new();
        registry.on_pre_tool_use(impatient, |_ctx| {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                HookResult::Abort("late abort".to_string())
            })
        });

        // The hook would abort only after its 10 ms budget has elapsed, so it is
        // skipped and the run must not abort.
        let outcome = registry.run_pre_tool_use(read_file_ctx()).await.unwrap();
        assert!(matches!(outcome, HookResult::Continue));
    }

    #[test]
    fn test_glob_match() {
        let cases = [
            ("*", "anything", true),
            ("write_*", "write_file", true),
            ("write_*", "read_file", false),
            ("*_file", "read_file", true),
            ("exact", "exact", true),
            ("exact", "not_exact", false),
        ];
        for (pattern, input, expected) in cases {
            assert_eq!(
                glob_match(pattern, input),
                expected,
                "glob_match({pattern:?}, {input:?})"
            );
        }
    }
}