strands_agents/types/
interrupt.rs

1//! Interrupt type definitions for human-in-the-loop workflows.
2
3use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
7/// An interrupt for pausing execution and requesting human input.
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Interrupt {
10    pub id: String,
11    pub name: String,
12    #[serde(skip_serializing_if = "Option::is_none")]
13    pub reason: Option<serde_json::Value>,
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub response: Option<serde_json::Value>,
16}
17
18impl Interrupt {
19    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
20        Self {
21            id: id.into(),
22            name: name.into(),
23            reason: None,
24            response: None,
25        }
26    }
27
28    pub fn with_reason(mut self, reason: serde_json::Value) -> Self {
29        self.reason = Some(reason);
30        self
31    }
32
33    pub fn with_response(mut self, response: serde_json::Value) -> Self {
34        self.response = Some(response);
35        self
36    }
37
38    pub fn has_response(&self) -> bool {
39        self.response.is_some()
40    }
41}
42
43/// State for managing interrupts during agent execution.
44#[derive(Debug, Clone, Default, Serialize, Deserialize)]
45pub struct InterruptState {
46    /// Interrupts raised by the user.
47    pub interrupts: HashMap<String, Interrupt>,
48    /// Additional context associated with an interrupt event.
49    pub context: HashMap<String, serde_json::Value>,
50    /// True if agent is in an interrupt state, False otherwise.
51    pub activated: bool,
52}
53
54impl InterruptState {
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    /// Activate the interrupt state.
60    pub fn activate(&mut self) {
61        self.activated = true;
62    }
63
64    /// Deactivate the interrupt state.
65    ///
66    /// Interrupts and context are cleared.
67    pub fn deactivate(&mut self) {
68        self.interrupts.clear();
69        self.context.clear();
70        self.activated = false;
71    }
72
73    /// Configure the interrupt state if resuming from an interrupt event.
74    pub fn resume(&mut self, responses: Vec<InterruptResponseContent>) -> Result<(), String> {
75        if !self.activated {
76            return Ok(());
77        }
78
79        for content in &responses {
80            let interrupt_id = &content.interrupt_response.interrupt_id;
81            let interrupt_response = &content.interrupt_response.response;
82
83            if let Some(interrupt) = self.interrupts.get_mut(interrupt_id) {
84                interrupt.response = Some(interrupt_response.clone());
85            } else {
86                return Err(format!("interrupt_id=<{}> | no interrupt found", interrupt_id));
87            }
88        }
89
90        self.context.insert(
91            "responses".to_string(),
92            serde_json::to_value(&responses).unwrap_or_default(),
93        );
94
95        Ok(())
96    }
97
98    pub fn add(&mut self, interrupt: Interrupt) {
99        self.interrupts.insert(interrupt.id.clone(), interrupt);
100    }
101
102    pub fn get(&self, id: &str) -> Option<&Interrupt> {
103        self.interrupts.get(id)
104    }
105
106    pub fn get_response(&self, id: &str) -> Option<&serde_json::Value> {
107        self.interrupts.get(id).and_then(|i| i.response.as_ref())
108    }
109
110    pub fn set_response(&mut self, id: &str, response: serde_json::Value) {
111        if let Some(interrupt) = self.interrupts.get_mut(id) {
112            interrupt.response = Some(response);
113        }
114    }
115
116    pub fn pending_interrupts(&self) -> Vec<&Interrupt> {
117        self.interrupts
118            .values()
119            .filter(|i| i.response.is_none())
120            .collect()
121    }
122
123    pub fn has_pending(&self) -> bool {
124        self.interrupts.values().any(|i| i.response.is_none())
125    }
126
127    pub fn to_dict(&self) -> HashMap<String, serde_json::Value> {
128        let mut dict = HashMap::new();
129        dict.insert(
130            "interrupts".to_string(),
131            serde_json::json!(self.interrupts
132                .iter()
133                .map(|(k, v)| (k.clone(), serde_json::to_value(v).unwrap_or_default()))
134                .collect::<HashMap<_, _>>()),
135        );
136        dict.insert("context".to_string(), serde_json::json!(self.context));
137        dict.insert("activated".to_string(), serde_json::json!(self.activated));
138        dict
139    }
140
141    pub fn from_dict(data: HashMap<String, serde_json::Value>) -> Self {
142        let interrupts = data
143            .get("interrupts")
144            .and_then(|v| v.as_object())
145            .map(|obj| {
146                obj.iter()
147                    .filter_map(|(k, v)| {
148                        serde_json::from_value::<Interrupt>(v.clone())
149                            .ok()
150                            .map(|i| (k.clone(), i))
151                    })
152                    .collect()
153            })
154            .unwrap_or_default();
155
156        let context = data
157            .get("context")
158            .and_then(|v| v.as_object())
159            .map(|obj| {
160                obj.iter()
161                    .map(|(k, v)| (k.clone(), v.clone()))
162                    .collect()
163            })
164            .unwrap_or_default();
165
166        let activated = data
167            .get("activated")
168            .and_then(|v| v.as_bool())
169            .unwrap_or(false);
170
171        Self {
172            interrupts,
173            context,
174            activated,
175        }
176    }
177}
178
179/// User response to an interrupt.
180#[derive(Debug, Clone, Serialize, Deserialize)]
181#[serde(rename_all = "camelCase")]
182pub struct InterruptResponse {
183    pub interrupt_id: String,
184    pub response: serde_json::Value,
185}
186
187/// Content block containing a user response to an interrupt.
188#[derive(Debug, Clone, Serialize, Deserialize)]
189#[serde(rename_all = "camelCase")]
190pub struct InterruptResponseContent {
191    pub interrupt_response: InterruptResponse,
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response content block answering interrupt `id` with `response`.
    fn response_for(id: &str, response: serde_json::Value) -> InterruptResponseContent {
        InterruptResponseContent {
            interrupt_response: InterruptResponse {
                interrupt_id: id.to_string(),
                response,
            },
        }
    }

    #[test]
    fn test_interrupt_creation() {
        let interrupt = Interrupt::new("int-1", "approval")
            .with_reason(serde_json::json!({"type": "delete"}));

        assert_eq!(interrupt.id, "int-1");
        assert_eq!(interrupt.name, "approval");
        assert!(interrupt.reason.is_some());
        assert!(!interrupt.has_response());
    }

    #[test]
    fn test_interrupt_state() {
        let mut state = InterruptState::new();
        state.add(Interrupt::new("int-1", "approval"));
        state.add(Interrupt::new("int-2", "confirmation"));

        assert!(state.has_pending());
        assert_eq!(state.pending_interrupts().len(), 2);

        state.set_response("int-1", serde_json::json!("approved"));
        assert_eq!(state.pending_interrupts().len(), 1);
    }

    #[test]
    fn test_resume_records_responses_when_activated() {
        let mut state = InterruptState::new();
        state.add(Interrupt::new("int-1", "approval"));
        state.activate();

        state
            .resume(vec![response_for("int-1", serde_json::json!("yes"))])
            .expect("int-1 is registered");

        assert_eq!(state.get_response("int-1"), Some(&serde_json::json!("yes")));
        assert_eq!(
            state.context["responses"],
            serde_json::json!([
                {"interruptResponse": {"interruptId": "int-1", "response": "yes"}}
            ])
        );
        assert!(!state.has_pending());
    }

    #[test]
    fn test_resume_is_noop_when_not_activated() {
        let mut state = InterruptState::new();
        state.add(Interrupt::new("int-1", "approval"));

        assert!(state
            .resume(vec![response_for("missing", serde_json::json!(1))])
            .is_ok());
        assert!(state.has_pending());
        assert!(state.context.is_empty());
    }

    #[test]
    fn test_resume_rejects_unknown_interrupt_id() {
        let mut state = InterruptState::new();
        state.add(Interrupt::new("int-1", "approval"));
        state.activate();

        let err = state
            .resume(vec![response_for("missing", serde_json::json!(1))])
            .unwrap_err();

        assert_eq!(err, "interrupt_id=<missing> | no interrupt found");
        assert!(!state.context.contains_key("responses"));
    }

    #[test]
    fn test_deactivate_clears_state() {
        let mut state = InterruptState::new();
        state.add(Interrupt::new("int-1", "approval"));
        state
            .context
            .insert("k".to_string(), serde_json::json!(true));
        state.activate();

        state.deactivate();

        assert!(!state.activated);
        assert!(state.interrupts.is_empty());
        assert!(state.context.is_empty());
    }

    #[test]
    fn test_dict_round_trip() {
        let mut state = InterruptState::new();
        state.add(
            Interrupt::new("int-1", "approval").with_reason(serde_json::json!({"type": "delete"})),
        );
        state.set_response("int-1", serde_json::json!("approved"));
        state.context.insert("k".to_string(), serde_json::json!(42));
        state.activate();

        let restored = InterruptState::from_dict(state.to_dict());

        assert!(restored.activated);
        assert_eq!(restored.context.get("k"), Some(&serde_json::json!(42)));
        let interrupt = restored.get("int-1").expect("interrupt survives round trip");
        assert_eq!(interrupt.name, "approval");
        assert_eq!(
            interrupt.reason,
            Some(serde_json::json!({"type": "delete"}))
        );
        assert_eq!(interrupt.response, Some(serde_json::json!("approved")));
    }

    #[test]
    fn test_from_dict_defaults_missing_keys() {
        let state = InterruptState::from_dict(HashMap::new());

        assert!(state.interrupts.is_empty());
        assert!(state.context.is_empty());
        assert!(!state.activated);
    }
}
222