1use crate::types::{
2 AgentCommand, ProposedToolCall, ToolApprovalAction, ToolApprovalPolicy, ToolDecision,
3};
4use thiserror::Error;
5
6#[derive(Debug, Clone, PartialEq, Eq)]
7enum ApprovalEntryState {
8 PendingUserDecision,
9 Ready(ToolDecision),
10 Dispatched,
11}
12
13#[derive(Debug, Clone, PartialEq)]
14struct ApprovalEntry {
15 tool_call: ProposedToolCall,
16 state: ApprovalEntryState,
17}
18
19#[derive(Debug, Clone, PartialEq)]
20pub struct ResolvedToolCall {
21 pub tool_call: ProposedToolCall,
22 pub decision: ToolDecision,
23}
24
25#[derive(Debug, Clone)]
26pub struct ApprovalStateMachine {
27 entries: Vec<ApprovalEntry>,
28 next_index: usize,
29}
30
31#[derive(Debug, Error, Clone, PartialEq, Eq)]
32pub enum ApprovalError {
33 #[error("unknown tool_call_id: {tool_call_id}")]
34 UnknownToolCallId { tool_call_id: String },
35
36 #[error("tool_call_id {tool_call_id} is already resolved")]
37 AlreadyResolved { tool_call_id: String },
38
39 #[error("invalid approval command for state machine")]
40 InvalidCommand,
41}
42
43impl ApprovalStateMachine {
44 pub fn new(tool_calls: Vec<ProposedToolCall>, policy: &ToolApprovalPolicy) -> Self {
45 let entries = tool_calls
46 .into_iter()
47 .map(|tool_call| {
48 let initial_state = match policy
49 .action_for(&tool_call.name, Some(&tool_call.arguments))
50 {
51 ToolApprovalAction::Approve => ApprovalEntryState::Ready(ToolDecision::Accept),
52 ToolApprovalAction::Deny => ApprovalEntryState::Ready(ToolDecision::Reject),
53 ToolApprovalAction::Ask => ApprovalEntryState::PendingUserDecision,
54 };
55
56 ApprovalEntry {
57 tool_call,
58 state: initial_state,
59 }
60 })
61 .collect();
62
63 Self {
64 entries,
65 next_index: 0,
66 }
67 }
68
69 pub fn pending_tool_call_ids(&self) -> Vec<String> {
70 self.entries
71 .iter()
72 .filter_map(|entry| {
73 if matches!(entry.state, ApprovalEntryState::PendingUserDecision) {
74 Some(entry.tool_call.id.clone())
75 } else {
76 None
77 }
78 })
79 .collect()
80 }
81
82 pub fn is_waiting_for_user(&self) -> bool {
83 self.entries
84 .iter()
85 .any(|entry| matches!(entry.state, ApprovalEntryState::PendingUserDecision))
86 }
87
88 pub fn is_complete(&self) -> bool {
89 self.next_index >= self.entries.len()
90 }
91
92 pub fn apply_command(&mut self, command: AgentCommand) -> Result<(), ApprovalError> {
93 match command {
94 AgentCommand::ResolveTool {
95 tool_call_id,
96 decision,
97 } => self.resolve_tool(&tool_call_id, decision),
98 AgentCommand::ResolveTools { decisions } => {
99 for (tool_call_id, decision) in decisions {
100 self.resolve_tool(&tool_call_id, decision)?;
101 }
102 Ok(())
103 }
104 _ => Err(ApprovalError::InvalidCommand),
105 }
106 }
107
108 pub fn resolve_tool(
109 &mut self,
110 tool_call_id: &str,
111 decision: ToolDecision,
112 ) -> Result<(), ApprovalError> {
113 let maybe_entry = self
114 .entries
115 .iter_mut()
116 .find(|entry| entry.tool_call.id == tool_call_id);
117
118 let Some(entry) = maybe_entry else {
119 return Err(ApprovalError::UnknownToolCallId {
120 tool_call_id: tool_call_id.to_string(),
121 });
122 };
123
124 match &entry.state {
125 ApprovalEntryState::PendingUserDecision => {
126 entry.state = ApprovalEntryState::Ready(decision);
127 Ok(())
128 }
129 ApprovalEntryState::Ready(existing) if *existing == decision => Ok(()),
130 ApprovalEntryState::Ready(_) | ApprovalEntryState::Dispatched => {
131 Err(ApprovalError::AlreadyResolved {
132 tool_call_id: tool_call_id.to_string(),
133 })
134 }
135 }
136 }
137
138 pub fn next_ready(&mut self) -> Option<ResolvedToolCall> {
139 while self.next_index < self.entries.len() {
140 let entry = self.entries.get_mut(self.next_index)?;
141
142 match &entry.state {
143 ApprovalEntryState::PendingUserDecision => return None,
144 ApprovalEntryState::Ready(decision) => {
145 let resolved = ResolvedToolCall {
146 tool_call: entry.tool_call.clone(),
147 decision: decision.clone(),
148 };
149 entry.state = ApprovalEntryState::Dispatched;
150 self.next_index += 1;
151 return Some(resolved);
152 }
153 ApprovalEntryState::Dispatched => {
154 self.next_index += 1;
155 }
156 }
157 }
158
159 None
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ToolApprovalAction;
    use serde_json::json;
    use std::collections::HashMap;

    /// Builds a tool call whose arguments embed its own id, so calls are distinguishable.
    fn tool_call(id: &str, name: &str) -> ProposedToolCall {
        ProposedToolCall {
            id: id.to_string(),
            name: name.to_string(),
            arguments: json!({"input": id}),
            metadata: None,
        }
    }

    #[test]
    fn incremental_decisions_buffer_until_prior_call_is_resolved() {
        let calls = vec![tool_call("tc_1", "tool_a"), tool_call("tc_2", "tool_b")];
        let mut machine = ApprovalStateMachine::new(calls, &ToolApprovalPolicy::None);

        let out_of_order = machine.resolve_tool("tc_2", ToolDecision::Accept);
        assert!(out_of_order.is_ok());

        assert!(machine.next_ready().is_none());

        let first_resolution = machine.resolve_tool("tc_1", ToolDecision::Reject);
        assert!(first_resolution.is_ok());

        let first = machine.next_ready();
        assert_eq!(
            first,
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_1", "tool_a"),
                decision: ToolDecision::Reject,
            })
        );

        let second = machine.next_ready();
        assert_eq!(
            second,
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_2", "tool_b"),
                decision: ToolDecision::Accept,
            })
        );

        assert!(machine.is_complete());
    }

    #[test]
    fn bulk_resolution_resolves_multiple_calls() {
        let calls = vec![
            tool_call("tc_1", "tool_a"),
            tool_call("tc_2", "tool_b"),
            tool_call("tc_3", "tool_c"),
        ];
        let mut machine = ApprovalStateMachine::new(calls, &ToolApprovalPolicy::None);

        let mut decisions = HashMap::new();
        decisions.insert("tc_1".to_string(), ToolDecision::Accept);
        decisions.insert("tc_2".to_string(), ToolDecision::Reject);

        let command_result = machine.apply_command(AgentCommand::ResolveTools { decisions });
        assert!(command_result.is_ok());

        assert_eq!(
            machine.next_ready(),
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_1", "tool_a"),
                decision: ToolDecision::Accept,
            })
        );

        assert_eq!(
            machine.next_ready(),
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_2", "tool_b"),
                decision: ToolDecision::Reject,
            })
        );

        assert!(machine.next_ready().is_none());
        assert_eq!(machine.pending_tool_call_ids(), vec!["tc_3".to_string()]);
    }

    #[test]
    fn bulk_resolution_with_unknown_id_returns_error() {
        let calls = vec![tool_call("tc_1", "tool_a")];
        let mut machine = ApprovalStateMachine::new(calls, &ToolApprovalPolicy::None);

        let mut decisions = HashMap::new();
        decisions.insert("tc_missing".to_string(), ToolDecision::Accept);

        assert_eq!(
            machine.apply_command(AgentCommand::ResolveTools { decisions }),
            Err(ApprovalError::UnknownToolCallId {
                tool_call_id: "tc_missing".to_string(),
            })
        );
        assert_eq!(machine.pending_tool_call_ids(), vec!["tc_1".to_string()]);
    }

    #[test]
    fn policy_applies_auto_approve_and_auto_deny() {
        let calls = vec![
            tool_call("tc_1", "safe_tool"),
            tool_call("tc_2", "danger_tool"),
            tool_call("tc_3", "unknown_tool"),
        ];

        let mut rules = HashMap::new();
        rules.insert("safe_tool".to_string(), ToolApprovalAction::Approve);
        rules.insert("danger_tool".to_string(), ToolApprovalAction::Deny);

        let policy = ToolApprovalPolicy::Custom {
            rules,
            default: ToolApprovalAction::Ask,
        };

        let mut machine = ApprovalStateMachine::new(calls, &policy);

        assert_eq!(
            machine.next_ready(),
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_1", "safe_tool"),
                decision: ToolDecision::Accept,
            })
        );

        assert_eq!(
            machine.next_ready(),
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_2", "danger_tool"),
                decision: ToolDecision::Reject,
            })
        );

        assert!(machine.next_ready().is_none());
        assert_eq!(machine.pending_tool_call_ids(), vec!["tc_3".to_string()]);
    }

    #[test]
    fn resolve_unknown_tool_call_returns_error() {
        let calls = vec![tool_call("tc_1", "tool_a")];
        let mut machine = ApprovalStateMachine::new(calls, &ToolApprovalPolicy::None);

        let error = machine.resolve_tool("tc_missing", ToolDecision::Accept);

        assert_eq!(
            error,
            Err(ApprovalError::UnknownToolCallId {
                tool_call_id: "tc_missing".to_string(),
            })
        );
    }

    #[test]
    fn resolve_same_decision_is_idempotent() {
        let calls = vec![tool_call("tc_1", "tool_a")];
        let mut machine = ApprovalStateMachine::new(calls, &ToolApprovalPolicy::None);

        assert!(machine.resolve_tool("tc_1", ToolDecision::Accept).is_ok());
        assert!(machine.resolve_tool("tc_1", ToolDecision::Accept).is_ok());

        assert_eq!(
            machine.next_ready(),
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_1", "tool_a"),
                decision: ToolDecision::Accept,
            })
        );
    }

    #[test]
    fn conflicting_decision_is_rejected_and_first_decision_kept() {
        let calls = vec![tool_call("tc_1", "tool_a")];
        let mut machine = ApprovalStateMachine::new(calls, &ToolApprovalPolicy::None);

        assert!(machine.resolve_tool("tc_1", ToolDecision::Accept).is_ok());
        assert_eq!(
            machine.resolve_tool("tc_1", ToolDecision::Reject),
            Err(ApprovalError::AlreadyResolved {
                tool_call_id: "tc_1".to_string(),
            })
        );

        assert_eq!(
            machine.next_ready(),
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_1", "tool_a"),
                decision: ToolDecision::Accept,
            })
        );
    }

    #[test]
    fn waiting_and_completion_flags_track_progress() {
        let calls = vec![tool_call("tc_1", "tool_a")];
        let mut machine = ApprovalStateMachine::new(calls, &ToolApprovalPolicy::None);

        assert!(machine.is_waiting_for_user());
        assert!(!machine.is_complete());

        assert!(machine.resolve_tool("tc_1", ToolDecision::Accept).is_ok());
        assert!(!machine.is_waiting_for_user());
        assert!(!machine.is_complete());

        assert!(machine.next_ready().is_some());
        assert!(machine.is_complete());
        assert!(machine.next_ready().is_none());
    }

    #[test]
    fn empty_batch_is_immediately_complete() {
        let mut machine = ApprovalStateMachine::new(Vec::new(), &ToolApprovalPolicy::None);

        assert!(machine.is_complete());
        assert!(!machine.is_waiting_for_user());
        assert!(machine.pending_tool_call_ids().is_empty());
        assert!(machine.next_ready().is_none());
    }
}