Skip to main content

stakpak_agent_core/
approval.rs

1use crate::types::{
2    AgentCommand, ProposedToolCall, ToolApprovalAction, ToolApprovalPolicy, ToolDecision,
3};
4use thiserror::Error;
5
6#[derive(Debug, Clone, PartialEq, Eq)]
7enum ApprovalEntryState {
8    PendingUserDecision,
9    Ready(ToolDecision),
10    Dispatched,
11}
12
13#[derive(Debug, Clone, PartialEq)]
14struct ApprovalEntry {
15    tool_call: ProposedToolCall,
16    state: ApprovalEntryState,
17}
18
19#[derive(Debug, Clone, PartialEq)]
20pub struct ResolvedToolCall {
21    pub tool_call: ProposedToolCall,
22    pub decision: ToolDecision,
23}
24
25#[derive(Debug, Clone)]
26pub struct ApprovalStateMachine {
27    entries: Vec<ApprovalEntry>,
28    next_index: usize,
29}
30
31#[derive(Debug, Error, Clone, PartialEq, Eq)]
32pub enum ApprovalError {
33    #[error("unknown tool_call_id: {tool_call_id}")]
34    UnknownToolCallId { tool_call_id: String },
35
36    #[error("tool_call_id {tool_call_id} is already resolved")]
37    AlreadyResolved { tool_call_id: String },
38
39    #[error("invalid approval command for state machine")]
40    InvalidCommand,
41}
42
43impl ApprovalStateMachine {
44    pub fn new(tool_calls: Vec<ProposedToolCall>, policy: &ToolApprovalPolicy) -> Self {
45        let entries = tool_calls
46            .into_iter()
47            .map(|tool_call| {
48                let initial_state = match policy
49                    .action_for(&tool_call.name, Some(&tool_call.arguments))
50                {
51                    ToolApprovalAction::Approve => ApprovalEntryState::Ready(ToolDecision::Accept),
52                    ToolApprovalAction::Deny => ApprovalEntryState::Ready(ToolDecision::Reject),
53                    ToolApprovalAction::Ask => ApprovalEntryState::PendingUserDecision,
54                };
55
56                ApprovalEntry {
57                    tool_call,
58                    state: initial_state,
59                }
60            })
61            .collect();
62
63        Self {
64            entries,
65            next_index: 0,
66        }
67    }
68
69    pub fn pending_tool_call_ids(&self) -> Vec<String> {
70        self.entries
71            .iter()
72            .filter_map(|entry| {
73                if matches!(entry.state, ApprovalEntryState::PendingUserDecision) {
74                    Some(entry.tool_call.id.clone())
75                } else {
76                    None
77                }
78            })
79            .collect()
80    }
81
82    pub fn is_waiting_for_user(&self) -> bool {
83        self.entries
84            .iter()
85            .any(|entry| matches!(entry.state, ApprovalEntryState::PendingUserDecision))
86    }
87
88    pub fn is_complete(&self) -> bool {
89        self.next_index >= self.entries.len()
90    }
91
92    pub fn apply_command(&mut self, command: AgentCommand) -> Result<(), ApprovalError> {
93        match command {
94            AgentCommand::ResolveTool {
95                tool_call_id,
96                decision,
97            } => self.resolve_tool(&tool_call_id, decision),
98            AgentCommand::ResolveTools { decisions } => {
99                for (tool_call_id, decision) in decisions {
100                    self.resolve_tool(&tool_call_id, decision)?;
101                }
102                Ok(())
103            }
104            _ => Err(ApprovalError::InvalidCommand),
105        }
106    }
107
108    pub fn resolve_tool(
109        &mut self,
110        tool_call_id: &str,
111        decision: ToolDecision,
112    ) -> Result<(), ApprovalError> {
113        let maybe_entry = self
114            .entries
115            .iter_mut()
116            .find(|entry| entry.tool_call.id == tool_call_id);
117
118        let Some(entry) = maybe_entry else {
119            return Err(ApprovalError::UnknownToolCallId {
120                tool_call_id: tool_call_id.to_string(),
121            });
122        };
123
124        match &entry.state {
125            ApprovalEntryState::PendingUserDecision => {
126                entry.state = ApprovalEntryState::Ready(decision);
127                Ok(())
128            }
129            ApprovalEntryState::Ready(existing) if *existing == decision => Ok(()),
130            ApprovalEntryState::Ready(_) | ApprovalEntryState::Dispatched => {
131                Err(ApprovalError::AlreadyResolved {
132                    tool_call_id: tool_call_id.to_string(),
133                })
134            }
135        }
136    }
137
138    pub fn next_ready(&mut self) -> Option<ResolvedToolCall> {
139        while self.next_index < self.entries.len() {
140            let entry = self.entries.get_mut(self.next_index)?;
141
142            match &entry.state {
143                ApprovalEntryState::PendingUserDecision => return None,
144                ApprovalEntryState::Ready(decision) => {
145                    let resolved = ResolvedToolCall {
146                        tool_call: entry.tool_call.clone(),
147                        decision: decision.clone(),
148                    };
149                    entry.state = ApprovalEntryState::Dispatched;
150                    self.next_index += 1;
151                    return Some(resolved);
152                }
153                ApprovalEntryState::Dispatched => {
154                    self.next_index += 1;
155                }
156            }
157        }
158
159        None
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ToolApprovalAction;
    use serde_json::json;
    use std::collections::HashMap;

    /// Builds a proposed tool call whose arguments echo its id.
    fn tool_call(id: &str, name: &str) -> ProposedToolCall {
        ProposedToolCall {
            id: id.to_string(),
            name: name.to_string(),
            arguments: json!({"input": id}),
            metadata: None,
        }
    }

    #[test]
    fn incremental_decisions_buffer_until_prior_call_is_resolved() {
        let calls = vec![tool_call("tc_1", "tool_a"), tool_call("tc_2", "tool_b")];
        let mut machine = ApprovalStateMachine::new(calls, &ToolApprovalPolicy::None);

        let out_of_order = machine.resolve_tool("tc_2", ToolDecision::Accept);
        assert!(out_of_order.is_ok());

        assert!(machine.next_ready().is_none());

        let first_resolution = machine.resolve_tool("tc_1", ToolDecision::Reject);
        assert!(first_resolution.is_ok());

        let first = machine.next_ready();
        assert_eq!(
            first,
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_1", "tool_a"),
                decision: ToolDecision::Reject,
            })
        );

        let second = machine.next_ready();
        assert_eq!(
            second,
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_2", "tool_b"),
                decision: ToolDecision::Accept,
            })
        );

        assert!(machine.is_complete());
    }

    #[test]
    fn bulk_resolution_resolves_multiple_calls() {
        let calls = vec![
            tool_call("tc_1", "tool_a"),
            tool_call("tc_2", "tool_b"),
            tool_call("tc_3", "tool_c"),
        ];
        let mut machine = ApprovalStateMachine::new(calls, &ToolApprovalPolicy::None);

        let mut decisions = HashMap::new();
        decisions.insert("tc_1".to_string(), ToolDecision::Accept);
        decisions.insert("tc_2".to_string(), ToolDecision::Reject);

        let command_result = machine.apply_command(AgentCommand::ResolveTools { decisions });
        assert!(command_result.is_ok());

        assert_eq!(
            machine.next_ready(),
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_1", "tool_a"),
                decision: ToolDecision::Accept,
            })
        );

        assert_eq!(
            machine.next_ready(),
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_2", "tool_b"),
                decision: ToolDecision::Reject,
            })
        );

        assert!(machine.next_ready().is_none());
        assert_eq!(machine.pending_tool_call_ids(), vec!["tc_3".to_string()]);
    }

    #[test]
    fn policy_applies_auto_approve_and_auto_deny() {
        let calls = vec![
            tool_call("tc_1", "safe_tool"),
            tool_call("tc_2", "danger_tool"),
            tool_call("tc_3", "unknown_tool"),
        ];

        let mut rules = HashMap::new();
        rules.insert("safe_tool".to_string(), ToolApprovalAction::Approve);
        rules.insert("danger_tool".to_string(), ToolApprovalAction::Deny);

        let policy = ToolApprovalPolicy::Custom {
            rules,
            default: ToolApprovalAction::Ask,
        };

        let mut machine = ApprovalStateMachine::new(calls, &policy);

        assert_eq!(
            machine.next_ready(),
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_1", "safe_tool"),
                decision: ToolDecision::Accept,
            })
        );

        assert_eq!(
            machine.next_ready(),
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_2", "danger_tool"),
                decision: ToolDecision::Reject,
            })
        );

        assert!(machine.next_ready().is_none());
        assert_eq!(machine.pending_tool_call_ids(), vec!["tc_3".to_string()]);
    }

    #[test]
    fn resolve_unknown_tool_call_returns_error() {
        let calls = vec![tool_call("tc_1", "tool_a")];
        let mut machine = ApprovalStateMachine::new(calls, &ToolApprovalPolicy::None);

        let error = machine.resolve_tool("tc_missing", ToolDecision::Accept);

        assert_eq!(
            error,
            Err(ApprovalError::UnknownToolCallId {
                tool_call_id: "tc_missing".to_string(),
            })
        );
    }

    #[test]
    fn resolve_same_decision_is_idempotent() {
        let calls = vec![tool_call("tc_1", "tool_a")];
        let mut machine = ApprovalStateMachine::new(calls, &ToolApprovalPolicy::None);

        assert!(machine.resolve_tool("tc_1", ToolDecision::Accept).is_ok());
        assert!(machine.resolve_tool("tc_1", ToolDecision::Accept).is_ok());

        assert_eq!(
            machine.next_ready(),
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_1", "tool_a"),
                decision: ToolDecision::Accept,
            })
        );
    }

    #[test]
    fn resolve_conflicting_decision_returns_already_resolved() {
        let calls = vec![tool_call("tc_1", "tool_a")];
        let mut machine = ApprovalStateMachine::new(calls, &ToolApprovalPolicy::None);

        assert!(machine.resolve_tool("tc_1", ToolDecision::Accept).is_ok());
        assert_eq!(
            machine.resolve_tool("tc_1", ToolDecision::Reject),
            Err(ApprovalError::AlreadyResolved {
                tool_call_id: "tc_1".to_string(),
            })
        );

        // The first decision must survive the rejected override.
        assert_eq!(
            machine.next_ready(),
            Some(ResolvedToolCall {
                tool_call: tool_call("tc_1", "tool_a"),
                decision: ToolDecision::Accept,
            })
        );
    }

    #[test]
    fn policy_decision_cannot_be_overridden_by_user() {
        let calls = vec![tool_call("tc_1", "danger_tool")];

        let mut rules = HashMap::new();
        rules.insert("danger_tool".to_string(), ToolApprovalAction::Deny);

        let policy = ToolApprovalPolicy::Custom {
            rules,
            default: ToolApprovalAction::Ask,
        };

        let mut machine = ApprovalStateMachine::new(calls, &policy);

        assert!(!machine.is_waiting_for_user());
        assert_eq!(
            machine.resolve_tool("tc_1", ToolDecision::Accept),
            Err(ApprovalError::AlreadyResolved {
                tool_call_id: "tc_1".to_string(),
            })
        );
        // Confirming the policy's own decision is accepted as a no-op.
        assert!(machine.resolve_tool("tc_1", ToolDecision::Reject).is_ok());
    }

    #[test]
    fn empty_batch_is_immediately_complete() {
        let mut machine = ApprovalStateMachine::new(Vec::new(), &ToolApprovalPolicy::None);

        assert!(machine.is_complete());
        assert!(!machine.is_waiting_for_user());
        assert!(machine.pending_tool_call_ids().is_empty());
        assert!(machine.next_ready().is_none());
    }

    #[test]
    fn waiting_and_complete_track_progress() {
        let calls = vec![tool_call("tc_1", "tool_a")];
        let mut machine = ApprovalStateMachine::new(calls, &ToolApprovalPolicy::None);

        assert!(machine.is_waiting_for_user());
        assert!(!machine.is_complete());

        assert!(machine.resolve_tool("tc_1", ToolDecision::Accept).is_ok());
        assert!(!machine.is_waiting_for_user());
        // Decided but not yet handed out.
        assert!(!machine.is_complete());

        assert!(machine.next_ready().is_some());
        assert!(machine.is_complete());
        assert!(machine.next_ready().is_none());
    }
}