Skip to main content

traitclaw_core/traits/
hook.rs

1//! Agent Hook — async lifecycle hooks for observability & interception.
2//!
//! The `AgentHook` trait provides lifecycle callbacks that fire at key
//! points during agent execution: agent start/end, before/after LLM calls,
//! before/after tool execution, each streaming chunk, and on errors.
6//!
7//! # Hook vs Tracker
8//!
9//! - **Tracker** — Internal state monitoring for the Steering subsystem
10//!   (Guard/Hint auto-configuration). Sync, lightweight, modifies `AgentState`.
//! - **Hook** — External observability and interception. Async, can perform
//!   I/O (send metrics to Datadog, write logs, etc.), and does NOT modify `AgentState`.
13//!
14//! # Architecture Decision
15//!
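//! Hook methods are `async fn`s declared through the `#[async_trait]` macro
//! rather than Rust 1.75+ native `async fn` in traits: native async trait
//! methods are not object-safe, and hooks are stored as `Box<dyn AgentHook>`.
//! Being async allows non-blocking I/O (e.g. HTTP calls to observability
//! services) without blocking the agent loop.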
19//!
20//! # Example
21//!
22//! ```rust,no_run
23//! use traitclaw_core::traits::hook::AgentHook;
24//! use std::time::Duration;
25//!
26//! struct TimingHook;
27//!
28//! #[async_trait::async_trait]
29//! impl AgentHook for TimingHook {
30//!     async fn on_provider_end(
31//!         &self,
32//!         _response: &traitclaw_core::types::completion::CompletionResponse,
33//!         duration: Duration,
34//!     ) {
35//!         println!("LLM call took {duration:?}");
36//!     }
37//! }
38//! ```
39
40use std::time::Duration;
41
42use async_trait::async_trait;
43
44use crate::agent::AgentOutput;
45use crate::types::completion::{CompletionRequest, CompletionResponse};
46
/// The result of a hook's interception decision.
///
/// Returned by [`AgentHook::before_tool_execute`] to allow or block
/// a tool execution.
///
/// Implements `PartialEq`/`Eq` so callers can compare decisions directly,
/// e.g. `action == HookAction::Continue`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Allow the tool to execute normally.
    Continue,
    /// Block the tool execution with the given reason.
    ///
    /// The reason string is returned to the LLM as the tool result,
    /// allowing it to adapt its behavior.
    Block(String),
}
61
/// Async lifecycle hooks for agent observability and interception.
///
/// All methods have default empty implementations, so you only need
/// to override the hooks you care about. The default
/// [`before_tool_execute`](AgentHook::before_tool_execute) returns
/// [`HookAction::Continue`], so an unmodified hook never blocks a tool.
///
/// Multiple hooks can be registered on a single agent and are called
/// sequentially in registration order.
///
/// Implementors must be `Send + Sync + 'static`: a hook cannot hold
/// borrowed (non-`'static`) data, so share state through owned types
/// such as `Arc` instead.
///
/// # Object Safety
///
/// This trait is object-safe and used as `Vec<Box<dyn AgentHook>>`.
/// The methods are declared through `#[async_trait]`, which makes each
/// method return a boxed future; that is what keeps the trait usable as
/// `dyn AgentHook`.
#[async_trait]
pub trait AgentHook: Send + Sync + 'static {
    /// Called when the agent starts processing input.
    ///
    /// `_input` is the input text the agent was given.
    async fn on_agent_start(&self, _input: &str) {}

    /// Called when the agent finishes processing.
    ///
    /// `_output` is the agent's final output. `_duration` is the elapsed
    /// time reported by the caller; presumably the whole run (confirm
    /// against the agent loop).
    async fn on_agent_end(&self, _output: &AgentOutput, _duration: Duration) {}

    /// Called before each LLM call is made, with the request about to be sent.
    async fn on_provider_start(&self, _request: &CompletionRequest) {}

    /// Called after each LLM call completes.
    ///
    /// `_response` is the provider's response; `_duration` is how long the
    /// call took.
    async fn on_provider_end(&self, _response: &CompletionResponse, _duration: Duration) {}

    /// Called before a tool is executed.
    ///
    /// `_name` is the tool's name and `_args` its JSON arguments.
    ///
    /// Return [`HookAction::Block`] to prevent execution. The block
    /// reason is returned to the LLM as the tool result.
    // NOTE(review): how multiple hooks' decisions combine (e.g. whether the
    // first `Block` short-circuits later hooks) is decided by the caller,
    // not visible here.
    async fn before_tool_execute(&self, _name: &str, _args: &serde_json::Value) -> HookAction {
        HookAction::Continue
    }

    /// Called after a tool finishes executing.
    ///
    /// `_result` is the tool's result as a string (presumably the same text
    /// handed back to the LLM; confirm). `_duration` is the execution time.
    async fn after_tool_execute(&self, _name: &str, _result: &str, _duration: Duration) {}

    /// Called for each streaming chunk received, with the chunk's text.
    async fn on_stream_chunk(&self, _chunk: &str) {}

    /// Called when an error occurs during execution.
    async fn on_error(&self, _error: &crate::Error) {}
}
104
105/// Blanket implementation: `Arc<T>` delegates to `T`.
106///
107/// This enables sharing a hook instance across multiple agents via `Arc`,
108/// useful for recording hooks or metrics collectors.
109#[async_trait]
110impl<T: AgentHook> AgentHook for std::sync::Arc<T> {
111    async fn on_agent_start(&self, input: &str) {
112        (**self).on_agent_start(input).await;
113    }
114    async fn on_agent_end(&self, output: &AgentOutput, duration: Duration) {
115        (**self).on_agent_end(output, duration).await;
116    }
117    async fn on_provider_start(&self, request: &CompletionRequest) {
118        (**self).on_provider_start(request).await;
119    }
120    async fn on_provider_end(&self, response: &CompletionResponse, duration: Duration) {
121        (**self).on_provider_end(response, duration).await;
122    }
123    async fn before_tool_execute(&self, name: &str, args: &serde_json::Value) -> HookAction {
124        (**self).before_tool_execute(name, args).await
125    }
126    async fn after_tool_execute(&self, name: &str, result: &str, duration: Duration) {
127        (**self).after_tool_execute(name, result, duration).await;
128    }
129    async fn on_stream_chunk(&self, chunk: &str) {
130        (**self).on_stream_chunk(chunk).await;
131    }
132    async fn on_error(&self, error: &crate::Error) {
133        (**self).on_error(error).await;
134    }
135}
136
137/// A hook that logs all lifecycle events using `tracing`.
138///
139/// # Example
140///
141/// ```rust,no_run
142/// use traitclaw_core::traits::hook::LoggingHook;
143///
144/// // let agent = Agent::builder()
145/// //     .model(my_provider)
146/// //     .hook(LoggingHook::new(tracing::Level::INFO))
147/// //     .build()?;
148/// ```
149pub struct LoggingHook {
150    level: tracing::Level,
151}
152
153impl LoggingHook {
154    /// Create a new logging hook at the given tracing level.
155    #[must_use]
156    pub fn new(level: tracing::Level) -> Self {
157        Self { level }
158    }
159}
160
161#[async_trait]
162impl AgentHook for LoggingHook {
163    async fn on_agent_start(&self, input: &str) {
164        match self.level {
165            tracing::Level::TRACE => tracing::trace!(input_len = input.len(), "Agent starting"),
166            tracing::Level::DEBUG => tracing::debug!(input_len = input.len(), "Agent starting"),
167            _ => tracing::info!(input_len = input.len(), "Agent starting"),
168        }
169    }
170
171    async fn on_agent_end(&self, _output: &AgentOutput, duration: Duration) {
172        #[allow(clippy::cast_possible_truncation)]
173        let ms = duration.as_millis() as u64;
174        match self.level {
175            tracing::Level::TRACE => tracing::trace!(duration_ms = ms, "Agent completed"),
176            tracing::Level::DEBUG => tracing::debug!(duration_ms = ms, "Agent completed"),
177            _ => tracing::info!(duration_ms = ms, "Agent completed"),
178        }
179    }
180
181    async fn on_provider_start(&self, _request: &CompletionRequest) {
182        match self.level {
183            tracing::Level::TRACE => tracing::trace!("LLM call starting"),
184            tracing::Level::DEBUG => tracing::debug!("LLM call starting"),
185            _ => tracing::info!("LLM call starting"),
186        }
187    }
188
189    async fn on_provider_end(&self, response: &CompletionResponse, duration: Duration) {
190        #[allow(clippy::cast_possible_truncation)]
191        let ms = duration.as_millis() as u64;
192        let tokens = response.usage.total_tokens;
193        match self.level {
194            tracing::Level::TRACE => {
195                tracing::trace!(duration_ms = ms, tokens, "LLM call completed")
196            }
197            tracing::Level::DEBUG => {
198                tracing::debug!(duration_ms = ms, tokens, "LLM call completed")
199            }
200            _ => tracing::info!(duration_ms = ms, tokens, "LLM call completed"),
201        }
202    }
203
204    async fn before_tool_execute(&self, name: &str, _args: &serde_json::Value) -> HookAction {
205        match self.level {
206            tracing::Level::TRACE => tracing::trace!(tool = name, "Tool executing"),
207            tracing::Level::DEBUG => tracing::debug!(tool = name, "Tool executing"),
208            _ => tracing::info!(tool = name, "Tool executing"),
209        }
210        HookAction::Continue
211    }
212
213    async fn after_tool_execute(&self, name: &str, _result: &str, duration: Duration) {
214        #[allow(clippy::cast_possible_truncation)]
215        let ms = duration.as_millis() as u64;
216        match self.level {
217            tracing::Level::TRACE => tracing::trace!(tool = name, duration_ms = ms, "Tool done"),
218            tracing::Level::DEBUG => tracing::debug!(tool = name, duration_ms = ms, "Tool done"),
219            _ => tracing::info!(tool = name, duration_ms = ms, "Tool done"),
220        }
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time guard: this signature only type-checks while `AgentHook`
    // stays object-safe.
    fn _object_safety_check(_hook: &dyn AgentHook) {}

    #[test]
    fn test_hook_action_variants() {
        assert!(matches!(HookAction::Continue, HookAction::Continue));

        match HookAction::Block("reason".into()) {
            HookAction::Block(reason) => assert_eq!(reason, "reason"),
            HookAction::Continue => panic!("expected HookAction::Block"),
        }
    }

    #[test]
    fn test_logging_hook_creation() {
        let level = tracing::Level::INFO;
        assert_eq!(LoggingHook::new(level).level, level);
    }
}