strands_agents/hooks/
mod.rs

//! Hook system for agent lifecycle events.
2
3use std::any::Any;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8
9use crate::agent::AgentResult;
10use crate::types::content::Message;
11use crate::types::streaming::StopReason;
12use crate::types::tools::{ToolResult, ToolUse};
13
14/// Interrupt for human-in-the-loop workflows.
15#[derive(Debug, Clone)]
16pub struct Interrupt {
17    pub id: String,
18    pub name: String,
19    pub reason: Option<serde_json::Value>,
20    pub response: Option<serde_json::Value>,
21}
22
23impl Interrupt {
24    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
25        Self {
26            id: id.into(),
27            name: name.into(),
28            reason: None,
29            response: None,
30        }
31    }
32
33    pub fn with_reason(mut self, reason: serde_json::Value) -> Self {
34        self.reason = Some(reason);
35        self
36    }
37}
38
39/// State for managing interrupts during agent execution.
40#[derive(Debug, Clone, Default)]
41pub struct InterruptState {
42    pub interrupts: HashMap<String, Interrupt>,
43}
44
45impl InterruptState {
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    pub fn add_interrupt(&mut self, interrupt: Interrupt) {
51        self.interrupts.insert(interrupt.id.clone(), interrupt);
52    }
53
54    pub fn get_response(&self, id: &str) -> Option<&serde_json::Value> {
55        self.interrupts.get(id).and_then(|i| i.response.as_ref())
56    }
57
58    pub fn set_response(&mut self, id: &str, response: serde_json::Value) {
59        if let Some(interrupt) = self.interrupts.get_mut(id) {
60            interrupt.response = Some(response);
61        }
62    }
63}
64
/// Common interface implemented by every concrete hook event type.
pub trait HookEventBase: Send + Sync {
    /// Reports whether callbacks should be run in reverse registration order
    /// for this event (intended for cleanup-style events).
    ///
    /// Defaults to `false`.
    fn should_reverse_callbacks(&self) -> bool {
        false
    }

    /// Borrows the event as [`Any`] so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Mutably borrows the event as [`Any`] so callers can downcast to the concrete type.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
78
79/// Event triggered when an agent has finished initialization.
80#[derive(Debug, Clone)]
81pub struct AgentInitializedEvent;
82
83impl HookEventBase for AgentInitializedEvent {
84    fn as_any(&self) -> &dyn Any {
85        self
86    }
87    fn as_any_mut(&mut self) -> &mut dyn Any {
88        self
89    }
90}
91
92/// Event triggered at the beginning of a new agent request.
93#[derive(Debug, Clone)]
94pub struct BeforeInvocationEvent;
95
96impl HookEventBase for BeforeInvocationEvent {
97    fn as_any(&self) -> &dyn Any {
98        self
99    }
100    fn as_any_mut(&mut self) -> &mut dyn Any {
101        self
102    }
103}
104
105/// Event triggered at the end of an agent request.
106#[derive(Debug, Clone)]
107pub struct AfterInvocationEvent {
108    pub result: Option<AgentResult>,
109}
110
111impl AfterInvocationEvent {
112    pub fn new(result: Option<AgentResult>) -> Self {
113        Self { result }
114    }
115}
116
117impl HookEventBase for AfterInvocationEvent {
118    fn should_reverse_callbacks(&self) -> bool {
119        true
120    }
121
122    fn as_any(&self) -> &dyn Any {
123        self
124    }
125    fn as_any_mut(&mut self) -> &mut dyn Any {
126        self
127    }
128}
129
130/// Event triggered when a message is added to the conversation.
131#[derive(Debug, Clone)]
132pub struct MessageAddedEvent {
133    pub message: Message,
134}
135
136impl MessageAddedEvent {
137    pub fn new(message: Message) -> Self {
138        Self { message }
139    }
140}
141
142impl HookEventBase for MessageAddedEvent {
143    fn as_any(&self) -> &dyn Any {
144        self
145    }
146    fn as_any_mut(&mut self) -> &mut dyn Any {
147        self
148    }
149}
150
/// Implemented by events that are able to raise interrupts.
pub trait Interruptible {
    /// Builds the id for an interrupt called `name` raised from this event.
    ///
    /// Implementations should derive the id deterministically from the event's
    /// context and `name`, so the same interrupt maps to the same id across
    /// sessions.
    fn interrupt_id(&self, name: &str) -> String;
}
159
160/// Event triggered before a tool is invoked.
161#[derive(Debug, Clone)]
162pub struct BeforeToolCallEvent {
163    pub tool_use: ToolUse,
164    pub invocation_state: HashMap<String, serde_json::Value>,
165    pub cancel_tool: Option<String>,
166}
167
168impl BeforeToolCallEvent {
169    pub fn new(tool_use: ToolUse) -> Self {
170        Self {
171            tool_use,
172            invocation_state: HashMap::new(),
173            cancel_tool: None,
174        }
175    }
176
177    pub fn with_state(mut self, state: HashMap<String, serde_json::Value>) -> Self {
178        self.invocation_state = state;
179        self
180    }
181
182    /// Cancel the tool call with a message.
183    pub fn cancel(&mut self, message: impl Into<String>) {
184        self.cancel_tool = Some(message.into());
185    }
186}
187
188impl Interruptible for BeforeToolCallEvent {
189    /// Generate a unique interrupt ID for before tool call events.
190    ///
191    /// Format: `v1:before_tool_call:{tool_use_id}:{uuid5(name)}`
192    fn interrupt_id(&self, name: &str) -> String {
193        use uuid::Uuid;
194        let name_uuid = Uuid::new_v5(&Uuid::NAMESPACE_OID, name.as_bytes());
195        format!(
196            "v1:before_tool_call:{}:{}",
197            self.tool_use.tool_use_id, name_uuid
198        )
199    }
200}
201
202impl HookEventBase for BeforeToolCallEvent {
203    fn as_any(&self) -> &dyn Any {
204        self
205    }
206    fn as_any_mut(&mut self) -> &mut dyn Any {
207        self
208    }
209}
210
211/// Event triggered after a tool invocation completes.
212#[derive(Debug, Clone)]
213pub struct AfterToolCallEvent {
214    pub tool_use: ToolUse,
215    pub invocation_state: HashMap<String, serde_json::Value>,
216    pub result: ToolResult,
217    pub exception: Option<String>,
218    pub cancel_message: Option<String>,
219}
220
221impl AfterToolCallEvent {
222    pub fn new(tool_use: ToolUse, result: ToolResult) -> Self {
223        Self {
224            tool_use,
225            invocation_state: HashMap::new(),
226            result,
227            exception: None,
228            cancel_message: None,
229        }
230    }
231
232    pub fn with_exception(mut self, exception: String) -> Self {
233        self.exception = Some(exception);
234        self
235    }
236}
237
238impl HookEventBase for AfterToolCallEvent {
239    fn should_reverse_callbacks(&self) -> bool {
240        true
241    }
242
243    fn as_any(&self) -> &dyn Any {
244        self
245    }
246    fn as_any_mut(&mut self) -> &mut dyn Any {
247        self
248    }
249}
250
251/// Event triggered before the model is invoked.
252#[derive(Debug, Clone)]
253pub struct BeforeModelCallEvent;
254
255impl HookEventBase for BeforeModelCallEvent {
256    fn as_any(&self) -> &dyn Any {
257        self
258    }
259    fn as_any_mut(&mut self) -> &mut dyn Any {
260        self
261    }
262}
263
264/// Model stop response data.
265#[derive(Debug, Clone)]
266pub struct ModelStopResponse {
267    pub message: Message,
268    pub stop_reason: StopReason,
269}
270
271/// Event triggered after the model invocation completes.
272#[derive(Debug, Clone)]
273pub struct AfterModelCallEvent {
274    pub stop_response: Option<ModelStopResponse>,
275    pub exception: Option<String>,
276}
277
278impl AfterModelCallEvent {
279    pub fn success(message: Message, stop_reason: StopReason) -> Self {
280        Self {
281            stop_response: Some(ModelStopResponse {
282                message,
283                stop_reason,
284            }),
285            exception: None,
286        }
287    }
288
289    pub fn error(exception: String) -> Self {
290        Self {
291            stop_response: None,
292            exception: Some(exception),
293        }
294    }
295}
296
297impl HookEventBase for AfterModelCallEvent {
298    fn should_reverse_callbacks(&self) -> bool {
299        true
300    }
301
302    fn as_any(&self) -> &dyn Any {
303        self
304    }
305    fn as_any_mut(&mut self) -> &mut dyn Any {
306        self
307    }
308}
309
310/// Enum wrapper for all hook events.
311#[derive(Debug, Clone)]
312pub enum HookEvent {
313    AgentInitialized(AgentInitializedEvent),
314    BeforeInvocation(BeforeInvocationEvent),
315    AfterInvocation(AfterInvocationEvent),
316    MessageAdded(MessageAddedEvent),
317    BeforeToolCall(BeforeToolCallEvent),
318    AfterToolCall(AfterToolCallEvent),
319    BeforeModelCall(BeforeModelCallEvent),
320    AfterModelCall(AfterModelCallEvent),
321}
322
323impl HookEvent {
324    pub fn should_reverse_callbacks(&self) -> bool {
325        match self {
326            Self::AfterInvocation(_) | Self::AfterToolCall(_) | Self::AfterModelCall(_) => true,
327            _ => false,
328        }
329    }
330}
331
332/// Trait for implementing hook providers.
333#[async_trait]
334pub trait HookProvider: Send + Sync {
335    /// Called when a hook event occurs.
336    async fn on_event(&self, event: &HookEvent);
337}
338
339/// Callback function type for hook events.
340pub type HookCallback = Arc<dyn Fn(&HookEvent) + Send + Sync>;
341
342/// Async callback function type for hook events.
343pub type AsyncHookCallback = Arc<dyn Fn(&HookEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send + Sync>;
344
345/// Registry for managing hook providers and callbacks.
346#[derive(Default)]
347pub struct HookRegistry {
348    providers: Vec<Arc<dyn HookProvider>>,
349    callbacks: Vec<HookCallback>,
350    async_callbacks: Vec<AsyncHookCallback>,
351}
352
353impl HookRegistry {
354    pub fn new() -> Self {
355        Self::default()
356    }
357
358    /// Add a hook provider.
359    pub fn add_provider(&mut self, provider: impl HookProvider + 'static) {
360        self.providers.push(Arc::new(provider));
361    }
362
363    /// Add a hook provider as Arc.
364    pub fn add_provider_arc(&mut self, provider: Arc<dyn HookProvider>) {
365        self.providers.push(provider);
366    }
367
368    /// Add a synchronous callback.
369    pub fn add_callback<F>(&mut self, callback: F)
370    where
371        F: Fn(&HookEvent) + Send + Sync + 'static,
372    {
373        self.callbacks.push(Arc::new(callback));
374    }
375
376    /// Add an async callback.
377    pub fn add_async_callback<F, Fut>(&mut self, callback: F)
378    where
379        F: Fn(&HookEvent) -> Fut + Send + Sync + 'static,
380        Fut: std::future::Future<Output = ()> + Send + 'static,
381    {
382        self.async_callbacks.push(Arc::new(move |event| {
383            Box::pin(callback(event))
384        }));
385    }
386
387    /// Invoke all callbacks for an event.
388    pub async fn invoke(&self, event: &HookEvent) -> Vec<Interrupt> {
389        let interrupts = Vec::new();
390
391        let reverse = event.should_reverse_callbacks();
392
393        if reverse {
394            for callback in self.callbacks.iter().rev() {
395                callback(event);
396            }
397        } else {
398            for callback in &self.callbacks {
399                callback(event);
400            }
401        }
402
403        if reverse {
404            for callback in self.async_callbacks.iter().rev() {
405                callback(event).await;
406            }
407        } else {
408            for callback in &self.async_callbacks {
409                callback(event).await;
410            }
411        }
412
413        if reverse {
414            for provider in self.providers.iter().rev() {
415                provider.on_event(event).await;
416            }
417        } else {
418            for provider in &self.providers {
419                provider.on_event(event).await;
420            }
421        }
422
423        interrupts
424    }
425
426    /// Invoke callbacks synchronously (panics if async callbacks exist).
427    pub fn invoke_sync(&self, event: &HookEvent) -> Vec<Interrupt> {
428        if !self.async_callbacks.is_empty() {
429            panic!("Cannot invoke sync with async callbacks registered");
430        }
431
432        let interrupts = Vec::new();
433        let reverse = event.should_reverse_callbacks();
434
435        if reverse {
436            for callback in self.callbacks.iter().rev() {
437                callback(event);
438            }
439        } else {
440            for callback in &self.callbacks {
441                callback(event);
442            }
443        }
444
445        interrupts
446    }
447
448    pub fn has_callbacks(&self) -> bool {
449        !self.providers.is_empty() || !self.callbacks.is_empty() || !self.async_callbacks.is_empty()
450    }
451
452    pub fn len(&self) -> usize {
453        self.providers.len() + self.callbacks.len() + self.async_callbacks.len()
454    }
455
456    pub fn is_empty(&self) -> bool {
457        self.len() == 0
458    }
459}