1use std::time::Duration;
41
42use async_trait::async_trait;
43
44use crate::agent::AgentOutput;
45use crate::types::completion::{CompletionRequest, CompletionResponse};
46
/// Decision returned by [`AgentHook::before_tool_execute`].
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare actions directly
/// (e.g. `assert_eq!(action, HookAction::Continue)`) instead of pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Let the tool call proceed.
    Continue,
    /// Ask for the tool call to be blocked, carrying a human-readable reason.
    Block(String),
}
61
62#[async_trait]
74pub trait AgentHook: Send + Sync + 'static {
75 async fn on_agent_start(&self, _input: &str) {}
77
78 async fn on_agent_end(&self, _output: &AgentOutput, _duration: Duration) {}
80
81 async fn on_provider_start(&self, _request: &CompletionRequest) {}
83
84 async fn on_provider_end(&self, _response: &CompletionResponse, _duration: Duration) {}
86
87 async fn before_tool_execute(&self, _name: &str, _args: &serde_json::Value) -> HookAction {
92 HookAction::Continue
93 }
94
95 async fn after_tool_execute(&self, _name: &str, _result: &str, _duration: Duration) {}
97
98 async fn on_stream_chunk(&self, _chunk: &str) {}
100
101 async fn on_error(&self, _error: &crate::Error) {}
103}
104
105#[async_trait]
110impl<T: AgentHook> AgentHook for std::sync::Arc<T> {
111 async fn on_agent_start(&self, input: &str) {
112 (**self).on_agent_start(input).await;
113 }
114 async fn on_agent_end(&self, output: &AgentOutput, duration: Duration) {
115 (**self).on_agent_end(output, duration).await;
116 }
117 async fn on_provider_start(&self, request: &CompletionRequest) {
118 (**self).on_provider_start(request).await;
119 }
120 async fn on_provider_end(&self, response: &CompletionResponse, duration: Duration) {
121 (**self).on_provider_end(response, duration).await;
122 }
123 async fn before_tool_execute(&self, name: &str, args: &serde_json::Value) -> HookAction {
124 (**self).before_tool_execute(name, args).await
125 }
126 async fn after_tool_execute(&self, name: &str, result: &str, duration: Duration) {
127 (**self).after_tool_execute(name, result, duration).await;
128 }
129 async fn on_stream_chunk(&self, chunk: &str) {
130 (**self).on_stream_chunk(chunk).await;
131 }
132 async fn on_error(&self, error: &crate::Error) {
133 (**self).on_error(error).await;
134 }
135}
136
137pub struct LoggingHook {
150 level: tracing::Level,
151}
152
153impl LoggingHook {
154 #[must_use]
156 pub fn new(level: tracing::Level) -> Self {
157 Self { level }
158 }
159}
160
161#[async_trait]
162impl AgentHook for LoggingHook {
163 async fn on_agent_start(&self, input: &str) {
164 match self.level {
165 tracing::Level::TRACE => tracing::trace!(input_len = input.len(), "Agent starting"),
166 tracing::Level::DEBUG => tracing::debug!(input_len = input.len(), "Agent starting"),
167 _ => tracing::info!(input_len = input.len(), "Agent starting"),
168 }
169 }
170
171 async fn on_agent_end(&self, _output: &AgentOutput, duration: Duration) {
172 #[allow(clippy::cast_possible_truncation)]
173 let ms = duration.as_millis() as u64;
174 match self.level {
175 tracing::Level::TRACE => tracing::trace!(duration_ms = ms, "Agent completed"),
176 tracing::Level::DEBUG => tracing::debug!(duration_ms = ms, "Agent completed"),
177 _ => tracing::info!(duration_ms = ms, "Agent completed"),
178 }
179 }
180
181 async fn on_provider_start(&self, _request: &CompletionRequest) {
182 match self.level {
183 tracing::Level::TRACE => tracing::trace!("LLM call starting"),
184 tracing::Level::DEBUG => tracing::debug!("LLM call starting"),
185 _ => tracing::info!("LLM call starting"),
186 }
187 }
188
189 async fn on_provider_end(&self, response: &CompletionResponse, duration: Duration) {
190 #[allow(clippy::cast_possible_truncation)]
191 let ms = duration.as_millis() as u64;
192 let tokens = response.usage.total_tokens;
193 match self.level {
194 tracing::Level::TRACE => {
195 tracing::trace!(duration_ms = ms, tokens, "LLM call completed")
196 }
197 tracing::Level::DEBUG => {
198 tracing::debug!(duration_ms = ms, tokens, "LLM call completed")
199 }
200 _ => tracing::info!(duration_ms = ms, tokens, "LLM call completed"),
201 }
202 }
203
204 async fn before_tool_execute(&self, name: &str, _args: &serde_json::Value) -> HookAction {
205 match self.level {
206 tracing::Level::TRACE => tracing::trace!(tool = name, "Tool executing"),
207 tracing::Level::DEBUG => tracing::debug!(tool = name, "Tool executing"),
208 _ => tracing::info!(tool = name, "Tool executing"),
209 }
210 HookAction::Continue
211 }
212
213 async fn after_tool_execute(&self, name: &str, _result: &str, duration: Duration) {
214 #[allow(clippy::cast_possible_truncation)]
215 let ms = duration.as_millis() as u64;
216 match self.level {
217 tracing::Level::TRACE => tracing::trace!(tool = name, duration_ms = ms, "Tool done"),
218 tracing::Level::DEBUG => tracing::debug!(tool = name, duration_ms = ms, "Tool done"),
219 _ => tracing::info!(tool = name, duration_ms = ms, "Tool done"),
220 }
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    // Never called: it compiles only if `AgentHook` can be used as a trait object.
    fn _assert_object_safe(_: &dyn AgentHook) {}

    #[test]
    fn test_hook_action_variants() {
        let proceed = HookAction::Continue;
        assert!(
            matches!(proceed, HookAction::Continue),
            "expected HookAction::Continue, got {proceed:?}"
        );

        match HookAction::Block("reason".into()) {
            HookAction::Block(reason) => assert_eq!(reason, "reason"),
            other => panic!("expected HookAction::Block, got {other:?}"),
        }
    }

    #[test]
    fn test_logging_hook_creation() {
        let level = tracing::Level::INFO;
        assert_eq!(LoggingHook::new(level).level, level);
    }
}