strands_agents/types/
interrupt.rs1use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Interrupt {
10 pub id: String,
11 pub name: String,
12 #[serde(skip_serializing_if = "Option::is_none")]
13 pub reason: Option<serde_json::Value>,
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub response: Option<serde_json::Value>,
16}
17
18impl Interrupt {
19 pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
20 Self {
21 id: id.into(),
22 name: name.into(),
23 reason: None,
24 response: None,
25 }
26 }
27
28 pub fn with_reason(mut self, reason: serde_json::Value) -> Self {
29 self.reason = Some(reason);
30 self
31 }
32
33 pub fn with_response(mut self, response: serde_json::Value) -> Self {
34 self.response = Some(response);
35 self
36 }
37
38 pub fn has_response(&self) -> bool {
39 self.response.is_some()
40 }
41}
42
43#[derive(Debug, Clone, Default, Serialize, Deserialize)]
45pub struct InterruptState {
46 pub interrupts: HashMap<String, Interrupt>,
48 pub context: HashMap<String, serde_json::Value>,
50 pub activated: bool,
52}
53
54impl InterruptState {
55 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn activate(&mut self) {
61 self.activated = true;
62 }
63
64 pub fn deactivate(&mut self) {
68 self.interrupts.clear();
69 self.context.clear();
70 self.activated = false;
71 }
72
73 pub fn resume(&mut self, responses: Vec<InterruptResponseContent>) -> Result<(), String> {
75 if !self.activated {
76 return Ok(());
77 }
78
79 for content in &responses {
80 let interrupt_id = &content.interrupt_response.interrupt_id;
81 let interrupt_response = &content.interrupt_response.response;
82
83 if let Some(interrupt) = self.interrupts.get_mut(interrupt_id) {
84 interrupt.response = Some(interrupt_response.clone());
85 } else {
86 return Err(format!("interrupt_id=<{}> | no interrupt found", interrupt_id));
87 }
88 }
89
90 self.context.insert(
91 "responses".to_string(),
92 serde_json::to_value(&responses).unwrap_or_default(),
93 );
94
95 Ok(())
96 }
97
98 pub fn add(&mut self, interrupt: Interrupt) {
99 self.interrupts.insert(interrupt.id.clone(), interrupt);
100 }
101
102 pub fn get(&self, id: &str) -> Option<&Interrupt> {
103 self.interrupts.get(id)
104 }
105
106 pub fn get_response(&self, id: &str) -> Option<&serde_json::Value> {
107 self.interrupts.get(id).and_then(|i| i.response.as_ref())
108 }
109
110 pub fn set_response(&mut self, id: &str, response: serde_json::Value) {
111 if let Some(interrupt) = self.interrupts.get_mut(id) {
112 interrupt.response = Some(response);
113 }
114 }
115
116 pub fn pending_interrupts(&self) -> Vec<&Interrupt> {
117 self.interrupts
118 .values()
119 .filter(|i| i.response.is_none())
120 .collect()
121 }
122
123 pub fn has_pending(&self) -> bool {
124 self.interrupts.values().any(|i| i.response.is_none())
125 }
126
127 pub fn to_dict(&self) -> HashMap<String, serde_json::Value> {
128 let mut dict = HashMap::new();
129 dict.insert(
130 "interrupts".to_string(),
131 serde_json::json!(self.interrupts
132 .iter()
133 .map(|(k, v)| (k.clone(), serde_json::to_value(v).unwrap_or_default()))
134 .collect::<HashMap<_, _>>()),
135 );
136 dict.insert("context".to_string(), serde_json::json!(self.context));
137 dict.insert("activated".to_string(), serde_json::json!(self.activated));
138 dict
139 }
140
141 pub fn from_dict(data: HashMap<String, serde_json::Value>) -> Self {
142 let interrupts = data
143 .get("interrupts")
144 .and_then(|v| v.as_object())
145 .map(|obj| {
146 obj.iter()
147 .filter_map(|(k, v)| {
148 serde_json::from_value::<Interrupt>(v.clone())
149 .ok()
150 .map(|i| (k.clone(), i))
151 })
152 .collect()
153 })
154 .unwrap_or_default();
155
156 let context = data
157 .get("context")
158 .and_then(|v| v.as_object())
159 .map(|obj| {
160 obj.iter()
161 .map(|(k, v)| (k.clone(), v.clone()))
162 .collect()
163 })
164 .unwrap_or_default();
165
166 let activated = data
167 .get("activated")
168 .and_then(|v| v.as_bool())
169 .unwrap_or(false);
170
171 Self {
172 interrupts,
173 context,
174 activated,
175 }
176 }
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
181#[serde(rename_all = "camelCase")]
182pub struct InterruptResponse {
183 pub interrupt_id: String,
184 pub response: serde_json::Value,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize)]
189#[serde(rename_all = "camelCase")]
190pub struct InterruptResponseContent {
191 pub interrupt_response: InterruptResponse,
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// Builds a single response entry for `resume`.
    fn response(id: &str, value: serde_json::Value) -> InterruptResponseContent {
        InterruptResponseContent {
            interrupt_response: InterruptResponse {
                interrupt_id: id.to_string(),
                response: value,
            },
        }
    }

    #[test]
    fn test_interrupt_creation() {
        let interrupt = Interrupt::new("int-1", "approval").with_reason(json!({"type": "delete"}));

        assert_eq!(interrupt.id, "int-1");
        assert_eq!(interrupt.name, "approval");
        assert!(interrupt.reason.is_some());
        assert!(!interrupt.has_response());
    }

    #[test]
    fn test_interrupt_state() {
        let mut state = InterruptState::new();
        state.add(Interrupt::new("int-1", "approval"));
        state.add(Interrupt::new("int-2", "confirmation"));

        assert!(state.has_pending());
        assert_eq!(state.pending_interrupts().len(), 2);

        state.set_response("int-1", json!("approved"));
        assert_eq!(state.pending_interrupts().len(), 1);
    }

    #[test]
    fn test_resume_applies_responses_when_activated() {
        let mut state = InterruptState::new();
        state.add(Interrupt::new("int-1", "approval"));
        state.activate();

        state.resume(vec![response("int-1", json!("yes"))]).unwrap();

        assert_eq!(state.get_response("int-1"), Some(&json!("yes")));
        assert!(!state.has_pending());
        assert!(state.context.contains_key("responses"));
    }

    #[test]
    fn test_resume_is_noop_when_inactive() {
        let mut state = InterruptState::new();
        state.add(Interrupt::new("int-1", "approval"));

        assert!(state.resume(vec![response("missing", json!(1))]).is_ok());
        assert!(state.get_response("int-1").is_none());
        assert!(state.context.is_empty());
    }

    #[test]
    fn test_resume_rejects_unknown_interrupt_id() {
        let mut state = InterruptState::new();
        state.add(Interrupt::new("int-1", "approval"));
        state.activate();

        let err = state
            .resume(vec![response("missing", json!(1))])
            .unwrap_err();
        assert_eq!(err, "interrupt_id=<missing> | no interrupt found");
        assert!(!state.context.contains_key("responses"));
    }

    #[test]
    fn test_dict_round_trip() {
        let mut state = InterruptState::new();
        state.add(Interrupt::new("int-1", "approval").with_reason(json!({"type": "delete"})));
        state.add(Interrupt::new("int-2", "confirmation").with_response(json!(true)));
        state.context.insert("k".to_string(), json!([1, 2]));
        state.activate();

        let restored = InterruptState::from_dict(state.to_dict());

        assert!(restored.activated);
        assert_eq!(restored.context.get("k"), Some(&json!([1, 2])));
        assert_eq!(
            restored.get("int-1").and_then(|i| i.reason.as_ref()),
            Some(&json!({"type": "delete"}))
        );
        assert_eq!(restored.get_response("int-2"), Some(&json!(true)));
        assert_eq!(restored.pending_interrupts().len(), 1);
    }

    #[test]
    fn test_from_dict_defaults_missing_fields() {
        let restored = InterruptState::from_dict(HashMap::new());

        assert!(restored.interrupts.is_empty());
        assert!(restored.context.is_empty());
        assert!(!restored.activated);
    }

    #[test]
    fn test_deactivate_clears_state() {
        let mut state = InterruptState::new();
        state.add(Interrupt::new("int-1", "approval"));
        state.context.insert("k".to_string(), json!(1));
        state.activate();

        state.deactivate();

        assert!(state.interrupts.is_empty());
        assert!(state.context.is_empty());
        assert!(!state.activated);
    }

    #[test]
    fn test_response_content_uses_camel_case() {
        let value = serde_json::to_value(response("int-1", json!("ok"))).unwrap();
        assert_eq!(
            value,
            json!({"interruptResponse": {"interruptId": "int-1", "response": "ok"}})
        );
    }
}
222