1use std::any::Any;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8
9use crate::agent::AgentResult;
10use crate::types::content::Message;
11use crate::types::streaming::StopReason;
12use crate::types::tools::{ToolResult, ToolUse};
13
14#[derive(Debug, Clone)]
16pub struct Interrupt {
17 pub id: String,
18 pub name: String,
19 pub reason: Option<serde_json::Value>,
20 pub response: Option<serde_json::Value>,
21}
22
23impl Interrupt {
24 pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
25 Self {
26 id: id.into(),
27 name: name.into(),
28 reason: None,
29 response: None,
30 }
31 }
32
33 pub fn with_reason(mut self, reason: serde_json::Value) -> Self {
34 self.reason = Some(reason);
35 self
36 }
37}
38
39#[derive(Debug, Clone, Default)]
41pub struct InterruptState {
42 pub interrupts: HashMap<String, Interrupt>,
43}
44
45impl InterruptState {
46 pub fn new() -> Self {
47 Self::default()
48 }
49
50 pub fn add_interrupt(&mut self, interrupt: Interrupt) {
51 self.interrupts.insert(interrupt.id.clone(), interrupt);
52 }
53
54 pub fn get_response(&self, id: &str) -> Option<&serde_json::Value> {
55 self.interrupts.get(id).and_then(|i| i.response.as_ref())
56 }
57
58 pub fn set_response(&mut self, id: &str, response: serde_json::Value) {
59 if let Some(interrupt) = self.interrupts.get_mut(id) {
60 interrupt.response = Some(response);
61 }
62 }
63}
64
/// Behaviour shared by every concrete hook event type.
///
/// Implementors expose themselves as [`Any`] so a holder of `&dyn HookEventBase`
/// can downcast back to the concrete event type.
pub trait HookEventBase: Send + Sync {
    /// Whether listeners should be run in reverse registration order for this event.
    ///
    /// Returns `false` unless an implementor overrides it.
    fn should_reverse_callbacks(&self) -> bool {
        false
    }

    /// Returns `self` as a shared [`Any`] reference, for use with `downcast_ref`.
    fn as_any(&self) -> &dyn Any;

    /// Returns `self` as a mutable [`Any`] reference, for use with `downcast_mut`.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
78
79#[derive(Debug, Clone)]
81pub struct AgentInitializedEvent;
82
83impl HookEventBase for AgentInitializedEvent {
84 fn as_any(&self) -> &dyn Any {
85 self
86 }
87 fn as_any_mut(&mut self) -> &mut dyn Any {
88 self
89 }
90}
91
92#[derive(Debug, Clone)]
94pub struct BeforeInvocationEvent;
95
96impl HookEventBase for BeforeInvocationEvent {
97 fn as_any(&self) -> &dyn Any {
98 self
99 }
100 fn as_any_mut(&mut self) -> &mut dyn Any {
101 self
102 }
103}
104
105#[derive(Debug, Clone)]
107pub struct AfterInvocationEvent {
108 pub result: Option<AgentResult>,
109}
110
111impl AfterInvocationEvent {
112 pub fn new(result: Option<AgentResult>) -> Self {
113 Self { result }
114 }
115}
116
117impl HookEventBase for AfterInvocationEvent {
118 fn should_reverse_callbacks(&self) -> bool {
119 true
120 }
121
122 fn as_any(&self) -> &dyn Any {
123 self
124 }
125 fn as_any_mut(&mut self) -> &mut dyn Any {
126 self
127 }
128}
129
130#[derive(Debug, Clone)]
132pub struct MessageAddedEvent {
133 pub message: Message,
134}
135
136impl MessageAddedEvent {
137 pub fn new(message: Message) -> Self {
138 Self { message }
139 }
140}
141
142impl HookEventBase for MessageAddedEvent {
143 fn as_any(&self) -> &dyn Any {
144 self
145 }
146 fn as_any_mut(&mut self) -> &mut dyn Any {
147 self
148 }
149}
150
/// Events that can derive stable identifiers for the interrupts they raise.
pub trait Interruptible {
    /// Builds the interrupt id that this event associates with the interrupt called `name`.
    fn interrupt_id(&self, name: &str) -> String;
}
159
160#[derive(Debug, Clone)]
162pub struct BeforeToolCallEvent {
163 pub tool_use: ToolUse,
164 pub invocation_state: HashMap<String, serde_json::Value>,
165 pub cancel_tool: Option<String>,
166}
167
168impl BeforeToolCallEvent {
169 pub fn new(tool_use: ToolUse) -> Self {
170 Self {
171 tool_use,
172 invocation_state: HashMap::new(),
173 cancel_tool: None,
174 }
175 }
176
177 pub fn with_state(mut self, state: HashMap<String, serde_json::Value>) -> Self {
178 self.invocation_state = state;
179 self
180 }
181
182 pub fn cancel(&mut self, message: impl Into<String>) {
184 self.cancel_tool = Some(message.into());
185 }
186}
187
188impl Interruptible for BeforeToolCallEvent {
189 fn interrupt_id(&self, name: &str) -> String {
193 use uuid::Uuid;
194 let name_uuid = Uuid::new_v5(&Uuid::NAMESPACE_OID, name.as_bytes());
195 format!(
196 "v1:before_tool_call:{}:{}",
197 self.tool_use.tool_use_id, name_uuid
198 )
199 }
200}
201
202impl HookEventBase for BeforeToolCallEvent {
203 fn as_any(&self) -> &dyn Any {
204 self
205 }
206 fn as_any_mut(&mut self) -> &mut dyn Any {
207 self
208 }
209}
210
211#[derive(Debug, Clone)]
213pub struct AfterToolCallEvent {
214 pub tool_use: ToolUse,
215 pub invocation_state: HashMap<String, serde_json::Value>,
216 pub result: ToolResult,
217 pub exception: Option<String>,
218 pub cancel_message: Option<String>,
219}
220
221impl AfterToolCallEvent {
222 pub fn new(tool_use: ToolUse, result: ToolResult) -> Self {
223 Self {
224 tool_use,
225 invocation_state: HashMap::new(),
226 result,
227 exception: None,
228 cancel_message: None,
229 }
230 }
231
232 pub fn with_exception(mut self, exception: String) -> Self {
233 self.exception = Some(exception);
234 self
235 }
236}
237
238impl HookEventBase for AfterToolCallEvent {
239 fn should_reverse_callbacks(&self) -> bool {
240 true
241 }
242
243 fn as_any(&self) -> &dyn Any {
244 self
245 }
246 fn as_any_mut(&mut self) -> &mut dyn Any {
247 self
248 }
249}
250
251#[derive(Debug, Clone)]
253pub struct BeforeModelCallEvent;
254
255impl HookEventBase for BeforeModelCallEvent {
256 fn as_any(&self) -> &dyn Any {
257 self
258 }
259 fn as_any_mut(&mut self) -> &mut dyn Any {
260 self
261 }
262}
263
264#[derive(Debug, Clone)]
266pub struct ModelStopResponse {
267 pub message: Message,
268 pub stop_reason: StopReason,
269}
270
271#[derive(Debug, Clone)]
273pub struct AfterModelCallEvent {
274 pub stop_response: Option<ModelStopResponse>,
275 pub exception: Option<String>,
276}
277
278impl AfterModelCallEvent {
279 pub fn success(message: Message, stop_reason: StopReason) -> Self {
280 Self {
281 stop_response: Some(ModelStopResponse {
282 message,
283 stop_reason,
284 }),
285 exception: None,
286 }
287 }
288
289 pub fn error(exception: String) -> Self {
290 Self {
291 stop_response: None,
292 exception: Some(exception),
293 }
294 }
295}
296
297impl HookEventBase for AfterModelCallEvent {
298 fn should_reverse_callbacks(&self) -> bool {
299 true
300 }
301
302 fn as_any(&self) -> &dyn Any {
303 self
304 }
305 fn as_any_mut(&mut self) -> &mut dyn Any {
306 self
307 }
308}
309
310#[derive(Debug, Clone)]
312pub enum HookEvent {
313 AgentInitialized(AgentInitializedEvent),
314 BeforeInvocation(BeforeInvocationEvent),
315 AfterInvocation(AfterInvocationEvent),
316 MessageAdded(MessageAddedEvent),
317 BeforeToolCall(BeforeToolCallEvent),
318 AfterToolCall(AfterToolCallEvent),
319 BeforeModelCall(BeforeModelCallEvent),
320 AfterModelCall(AfterModelCallEvent),
321}
322
323impl HookEvent {
324 pub fn should_reverse_callbacks(&self) -> bool {
325 match self {
326 Self::AfterInvocation(_) | Self::AfterToolCall(_) | Self::AfterModelCall(_) => true,
327 _ => false,
328 }
329 }
330}
331
332#[async_trait]
334pub trait HookProvider: Send + Sync {
335 async fn on_event(&self, event: &HookEvent);
337}
338
339pub type HookCallback = Arc<dyn Fn(&HookEvent) + Send + Sync>;
341
342pub type AsyncHookCallback = Arc<dyn Fn(&HookEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
344
345#[derive(Default)]
347pub struct HookRegistry {
348 providers: Vec<Arc<dyn HookProvider>>,
349 callbacks: Vec<HookCallback>,
350 async_callbacks: Vec<AsyncHookCallback>,
351}
352
353impl HookRegistry {
354 pub fn new() -> Self {
355 Self::default()
356 }
357
358 pub fn add_provider(&mut self, provider: impl HookProvider + 'static) {
360 self.providers.push(Arc::new(provider));
361 }
362
363 pub fn add_provider_arc(&mut self, provider: Arc<dyn HookProvider>) {
365 self.providers.push(provider);
366 }
367
368 pub fn add_callback<F>(&mut self, callback: F)
370 where
371 F: Fn(&HookEvent) + Send + Sync + 'static,
372 {
373 self.callbacks.push(Arc::new(callback));
374 }
375
376 pub fn add_async_callback<F, Fut>(&mut self, callback: F)
378 where
379 F: Fn(&HookEvent) -> Fut + Send + Sync + 'static,
380 Fut: std::future::Future<Output = ()> + Send + 'static,
381 {
382 self.async_callbacks.push(Arc::new(move |event| {
383 Box::pin(callback(event))
384 }));
385 }
386
387 pub async fn invoke(&self, event: &HookEvent) -> Vec<Interrupt> {
389 let interrupts = Vec::new();
390
391 let reverse = event.should_reverse_callbacks();
392
393 if reverse {
394 for callback in self.callbacks.iter().rev() {
395 callback(event);
396 }
397 } else {
398 for callback in &self.callbacks {
399 callback(event);
400 }
401 }
402
403 if reverse {
404 for callback in self.async_callbacks.iter().rev() {
405 callback(event).await;
406 }
407 } else {
408 for callback in &self.async_callbacks {
409 callback(event).await;
410 }
411 }
412
413 if reverse {
414 for provider in self.providers.iter().rev() {
415 provider.on_event(event).await;
416 }
417 } else {
418 for provider in &self.providers {
419 provider.on_event(event).await;
420 }
421 }
422
423 interrupts
424 }
425
426 pub fn invoke_sync(&self, event: &HookEvent) -> Vec<Interrupt> {
428 if !self.async_callbacks.is_empty() {
429 panic!("Cannot invoke sync with async callbacks registered");
430 }
431
432 let interrupts = Vec::new();
433 let reverse = event.should_reverse_callbacks();
434
435 if reverse {
436 for callback in self.callbacks.iter().rev() {
437 callback(event);
438 }
439 } else {
440 for callback in &self.callbacks {
441 callback(event);
442 }
443 }
444
445 interrupts
446 }
447
448 pub fn has_callbacks(&self) -> bool {
449 !self.providers.is_empty() || !self.callbacks.is_empty() || !self.async_callbacks.is_empty()
450 }
451
452 pub fn len(&self) -> usize {
453 self.providers.len() + self.callbacks.len() + self.async_callbacks.len()
454 }
455
456 pub fn is_empty(&self) -> bool {
457 self.len() == 0
458 }
459}