Skip to main content

serdes_ai_tools/
deferred.rs

1//! Deferred tool execution for approval flows.
2//!
3//! This module provides types for handling tool calls that require
4//! human approval before execution.
5
6use serde::{Deserialize, Serialize};
7
8use crate::ToolReturn;
9
10/// A tool call that was deferred for approval.
11///
12/// When a tool requires approval (via `ToolError::ApprovalRequired`),
13/// the call is captured as a `DeferredToolCall` for later processing.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct DeferredToolCall {
16    /// Name of the tool.
17    pub tool_name: String,
18    /// Arguments for the tool call.
19    pub args: serde_json::Value,
20    /// Tool call ID.
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub tool_call_id: Option<String>,
23}
24
25impl DeferredToolCall {
26    /// Create a new deferred tool call.
27    #[must_use]
28    pub fn new(tool_name: impl Into<String>, args: serde_json::Value) -> Self {
29        Self {
30            tool_name: tool_name.into(),
31            args,
32            tool_call_id: None,
33        }
34    }
35
36    /// Set the tool call ID.
37    #[must_use]
38    pub fn with_tool_call_id(mut self, id: impl Into<String>) -> Self {
39        self.tool_call_id = Some(id.into());
40        self
41    }
42
43    /// Create an approved decision.
44    #[must_use]
45    pub fn approve(&self) -> DeferredToolDecision {
46        DeferredToolDecision::Approved
47    }
48
49    /// Create a denied decision.
50    #[must_use]
51    pub fn deny(&self, message: impl Into<String>) -> DeferredToolDecision {
52        DeferredToolDecision::Denied(message.into())
53    }
54
55    /// Create a custom result decision.
56    #[must_use]
57    pub fn with_result(&self, result: ToolReturn) -> DeferredToolDecision {
58        DeferredToolDecision::CustomResult(result)
59    }
60}
61
62/// Collection of deferred tool calls.
63#[derive(Debug, Clone, Default, Serialize, Deserialize)]
64pub struct DeferredToolRequests {
65    /// The deferred calls.
66    pub calls: Vec<DeferredToolCall>,
67}
68
69impl DeferredToolRequests {
70    /// Create a new empty collection.
71    #[must_use]
72    pub fn new() -> Self {
73        Self { calls: Vec::new() }
74    }
75
76    /// Add a deferred call.
77    pub fn add(&mut self, call: DeferredToolCall) {
78        self.calls.push(call);
79    }
80
81    /// Check if empty.
82    #[must_use]
83    pub fn is_empty(&self) -> bool {
84        self.calls.is_empty()
85    }
86
87    /// Get the number of deferred calls.
88    #[must_use]
89    pub fn len(&self) -> usize {
90        self.calls.len()
91    }
92
93    /// Get a call by index.
94    #[must_use]
95    pub fn get(&self, index: usize) -> Option<&DeferredToolCall> {
96        self.calls.get(index)
97    }
98
99    /// Iterate over the calls.
100    pub fn iter(&self) -> impl Iterator<Item = &DeferredToolCall> {
101        self.calls.iter()
102    }
103
104    /// Get calls by tool name.
105    #[must_use]
106    pub fn by_tool(&self, name: &str) -> Vec<&DeferredToolCall> {
107        self.calls.iter().filter(|c| c.tool_name == name).collect()
108    }
109
110    /// Clear all calls.
111    pub fn clear(&mut self) {
112        self.calls.clear();
113    }
114
115    /// Approve all calls.
116    #[must_use]
117    pub fn approve_all(&self) -> DeferredToolDecisions {
118        DeferredToolDecisions {
119            decisions: self
120                .calls
121                .iter()
122                .map(|_| DeferredToolDecision::Approved)
123                .collect(),
124        }
125    }
126
127    /// Deny all calls with a message.
128    #[must_use]
129    pub fn deny_all(&self, message: impl Into<String>) -> DeferredToolDecisions {
130        let msg = message.into();
131        DeferredToolDecisions {
132            decisions: self
133                .calls
134                .iter()
135                .map(|_| DeferredToolDecision::Denied(msg.clone()))
136                .collect(),
137        }
138    }
139}
140
141impl FromIterator<DeferredToolCall> for DeferredToolRequests {
142    fn from_iter<T: IntoIterator<Item = DeferredToolCall>>(iter: T) -> Self {
143        Self {
144            calls: iter.into_iter().collect(),
145        }
146    }
147}
148
149/// Decision about a deferred tool call.
150#[derive(Debug, Clone)]
151pub enum DeferredToolDecision {
152    /// Approve the tool call.
153    Approved,
154    /// Deny with a message to send back to the model.
155    Denied(String),
156    /// Provide a custom result.
157    CustomResult(ToolReturn),
158}
159
160impl DeferredToolDecision {
161    /// Check if approved.
162    #[must_use]
163    pub fn is_approved(&self) -> bool {
164        matches!(self, Self::Approved)
165    }
166
167    /// Check if denied.
168    #[must_use]
169    pub fn is_denied(&self) -> bool {
170        matches!(self, Self::Denied(_))
171    }
172
173    /// Check if custom result.
174    #[must_use]
175    pub fn is_custom(&self) -> bool {
176        matches!(self, Self::CustomResult(_))
177    }
178
179    /// Get the denial message if denied.
180    #[must_use]
181    pub fn denial_message(&self) -> Option<&str> {
182        match self {
183            Self::Denied(msg) => Some(msg),
184            _ => None,
185        }
186    }
187}
188
189/// Collection of decisions for deferred tools.
190#[derive(Debug, Clone, Default)]
191pub struct DeferredToolDecisions {
192    /// The decisions.
193    pub decisions: Vec<DeferredToolDecision>,
194}
195
196impl DeferredToolDecisions {
197    /// Create a new empty collection.
198    #[must_use]
199    pub fn new() -> Self {
200        Self {
201            decisions: Vec::new(),
202        }
203    }
204
205    /// Add a decision.
206    pub fn add(&mut self, decision: DeferredToolDecision) {
207        self.decisions.push(decision);
208    }
209
210    /// Check if empty.
211    #[must_use]
212    pub fn is_empty(&self) -> bool {
213        self.decisions.is_empty()
214    }
215
216    /// Get the number of decisions.
217    #[must_use]
218    pub fn len(&self) -> usize {
219        self.decisions.len()
220    }
221
222    /// Check if all are approved.
223    #[must_use]
224    pub fn all_approved(&self) -> bool {
225        self.decisions.iter().all(|d| d.is_approved())
226    }
227
228    /// Check if any are denied.
229    #[must_use]
230    pub fn any_denied(&self) -> bool {
231        self.decisions.iter().any(|d| d.is_denied())
232    }
233}
234
235impl FromIterator<DeferredToolDecision> for DeferredToolDecisions {
236    fn from_iter<T: IntoIterator<Item = DeferredToolDecision>>(iter: T) -> Self {
237        Self {
238            decisions: iter.into_iter().collect(),
239        }
240    }
241}
242
243/// Result for a single deferred tool.
244#[derive(Debug, Clone)]
245pub struct DeferredToolResult {
246    /// Tool call ID.
247    pub tool_call_id: Option<String>,
248    /// The result.
249    pub result: ToolReturn,
250}
251
252impl DeferredToolResult {
253    /// Create a new result.
254    #[must_use]
255    pub fn new(result: ToolReturn) -> Self {
256        Self {
257            tool_call_id: None,
258            result,
259        }
260    }
261
262    /// Set the tool call ID.
263    #[must_use]
264    pub fn with_tool_call_id(mut self, id: impl Into<String>) -> Self {
265        self.tool_call_id = Some(id.into());
266        self
267    }
268
269    /// Create an approved result.
270    #[must_use]
271    pub fn approved() -> Self {
272        Self::new(ToolReturn::text("Tool execution approved"))
273    }
274
275    /// Create a denied result.
276    #[must_use]
277    pub fn denied(message: impl Into<String>) -> Self {
278        Self::new(ToolReturn::error(message))
279    }
280}
281
282/// Results for all deferred tools.
283#[derive(Debug, Clone, Default)]
284pub struct DeferredToolResults {
285    /// The results.
286    pub results: Vec<DeferredToolResult>,
287}
288
289impl DeferredToolResults {
290    /// Create a new empty collection.
291    #[must_use]
292    pub fn new() -> Self {
293        Self {
294            results: Vec::new(),
295        }
296    }
297
298    /// Add a result.
299    pub fn add(&mut self, result: DeferredToolResult) {
300        self.results.push(result);
301    }
302
303    /// Create a single approved result.
304    #[must_use]
305    pub fn approved(id: Option<String>) -> Self {
306        let mut result = DeferredToolResult::approved();
307        if let Some(id) = id {
308            result = result.with_tool_call_id(id);
309        }
310        Self {
311            results: vec![result],
312        }
313    }
314
315    /// Create a single denied result.
316    #[must_use]
317    pub fn denied(id: Option<String>, message: impl Into<String>) -> Self {
318        let mut result = DeferredToolResult::denied(message);
319        if let Some(id) = id {
320            result = result.with_tool_call_id(id);
321        }
322        Self {
323            results: vec![result],
324        }
325    }
326
327    /// Check if empty.
328    #[must_use]
329    pub fn is_empty(&self) -> bool {
330        self.results.is_empty()
331    }
332
333    /// Get the number of results.
334    #[must_use]
335    pub fn len(&self) -> usize {
336        self.results.len()
337    }
338}
339
340impl FromIterator<DeferredToolResult> for DeferredToolResults {
341    fn from_iter<T: IntoIterator<Item = DeferredToolResult>>(iter: T) -> Self {
342        Self {
343            results: iter.into_iter().collect(),
344        }
345    }
346}
347
/// Marker value signalling that a tool call was approved.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ToolApproved;

/// Marker value signalling that a tool call was denied, together with the reason.
///
/// Derives `PartialEq`/`Eq` like [`ToolApproved`] so both markers can be compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolDenied {
    /// Why the call was denied.
    pub message: String,
}

impl ToolDenied {
    /// Creates a denial carrying `message`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }
}

impl std::fmt::Display for ToolDenied {
    /// Writes the denial message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}
368
369/// Trait for types that can handle tool approval requests.
370#[allow(async_fn_in_trait)]
371pub trait ToolApprover {
372    /// Handle an approval request for a tool call.
373    ///
374    /// Returns the decision for this tool call.
375    async fn approve(&self, call: &DeferredToolCall) -> DeferredToolDecision;
376}
377
378/// Auto-approve all tool calls.
379#[derive(Debug, Clone, Copy, Default)]
380pub struct AutoApprover;
381
382impl ToolApprover for AutoApprover {
383    async fn approve(&self, _call: &DeferredToolCall) -> DeferredToolDecision {
384        DeferredToolDecision::Approved
385    }
386}
387
388/// Auto-deny all tool calls.
389#[derive(Debug, Clone)]
390pub struct AutoDenier {
391    message: String,
392}
393
394impl AutoDenier {
395    /// Create a new auto-denier with the given message.
396    #[must_use]
397    pub fn new(message: impl Into<String>) -> Self {
398        Self {
399            message: message.into(),
400        }
401    }
402}
403
404impl ToolApprover for AutoDenier {
405    async fn approve(&self, _call: &DeferredToolCall) -> DeferredToolDecision {
406        DeferredToolDecision::Denied(self.message.clone())
407    }
408}
409
#[cfg(test)]
mod tests {
    //! Unit tests for the deferred-approval types. Every test uses only the
    //! public API; the async approver tests run on tokio.
    use super::*;

    #[test]
    fn test_deferred_tool_call() {
        let call = DeferredToolCall::new("my_tool", serde_json::json!({"x": 1}))
            .with_tool_call_id("call_123");

        assert_eq!(call.tool_name, "my_tool");
        assert_eq!(call.args, serde_json::json!({"x": 1}));
        assert_eq!(call.tool_call_id, Some("call_123".to_string()));
    }

    #[test]
    fn test_deferred_tool_call_without_id() {
        let call = DeferredToolCall::new("my_tool", serde_json::json!(null));
        assert!(call.tool_call_id.is_none());
    }

    #[test]
    fn test_call_decision_helpers() {
        let call = DeferredToolCall::new("my_tool", serde_json::json!({}));
        assert!(call.approve().is_approved());
        assert_eq!(call.deny("nope").denial_message(), Some("nope"));
        assert!(call.with_result(ToolReturn::text("custom")).is_custom());
    }

    #[test]
    fn test_deferred_tool_requests() {
        let mut requests = DeferredToolRequests::new();
        assert!(requests.is_empty());

        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool2", serde_json::json!({})));

        assert_eq!(requests.len(), 2);
        assert!(!requests.is_empty());
    }

    #[test]
    fn test_requests_accessors() {
        let mut requests: DeferredToolRequests = vec!["a", "b"]
            .into_iter()
            .map(|name| DeferredToolCall::new(name, serde_json::json!({})))
            .collect();

        assert_eq!(requests.get(1).map(|c| c.tool_name.as_str()), Some("b"));
        assert!(requests.get(2).is_none());
        let names: Vec<&str> = requests.iter().map(|c| c.tool_name.as_str()).collect();
        assert_eq!(names, ["a", "b"]);

        requests.clear();
        assert!(requests.is_empty());
    }

    #[test]
    fn test_by_tool() {
        let mut requests = DeferredToolRequests::new();
        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool2", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));

        let tool1_calls = requests.by_tool("tool1");
        assert_eq!(tool1_calls.len(), 2);
        assert!(tool1_calls.iter().all(|c| c.tool_name == "tool1"));
        assert!(requests.by_tool("missing").is_empty());
    }

    #[test]
    fn test_approve_all() {
        let mut requests = DeferredToolRequests::new();
        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool2", serde_json::json!({})));

        let decisions = requests.approve_all();
        assert_eq!(decisions.len(), 2);
        assert!(decisions.all_approved());
        assert!(!decisions.any_denied());
    }

    #[test]
    fn test_empty_requests_decisions() {
        let requests = DeferredToolRequests::default();

        let approved = requests.approve_all();
        assert!(approved.is_empty());
        // `all` over nothing is vacuously true.
        assert!(approved.all_approved());

        assert!(!requests.deny_all("x").any_denied());
    }

    #[test]
    fn test_deny_all() {
        let mut requests = DeferredToolRequests::new();
        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool2", serde_json::json!({})));

        let decisions = requests.deny_all("Not allowed");
        assert_eq!(decisions.len(), 2);
        assert!(decisions.any_denied());
        assert!(!decisions.all_approved());
        assert!(decisions
            .decisions
            .iter()
            .all(|d| d.denial_message() == Some("Not allowed")));
    }

    #[test]
    fn test_deferred_tool_decision() {
        let approved = DeferredToolDecision::Approved;
        assert!(approved.is_approved());
        assert!(!approved.is_denied());
        assert!(!approved.is_custom());
        assert_eq!(approved.denial_message(), None);

        let denied = DeferredToolDecision::Denied("No".into());
        assert!(denied.is_denied());
        assert!(!denied.is_approved());
        assert_eq!(denied.denial_message(), Some("No"));

        let custom = DeferredToolDecision::CustomResult(ToolReturn::text("custom"));
        assert!(custom.is_custom());
        assert!(!custom.is_approved());
        assert_eq!(custom.denial_message(), None);
    }

    #[test]
    fn test_mixed_decisions() {
        let decisions: DeferredToolDecisions = vec![
            DeferredToolDecision::Approved,
            DeferredToolDecision::Denied("No".into()),
        ]
        .into_iter()
        .collect();

        assert_eq!(decisions.len(), 2);
        assert!(!decisions.all_approved());
        assert!(decisions.any_denied());
    }

    #[test]
    fn test_deferred_tool_result() {
        let result = DeferredToolResult::approved().with_tool_call_id("id1");
        assert_eq!(result.tool_call_id, Some("id1".to_string()));
        assert!(DeferredToolResult::approved().tool_call_id.is_none());

        let denied = DeferredToolResult::denied("Not allowed");
        assert!(denied.result.is_error());
    }

    #[test]
    fn test_deferred_tool_results() {
        let results = DeferredToolResults::approved(Some("id1".to_string()));
        assert_eq!(results.len(), 1);
        assert_eq!(results.results[0].tool_call_id.as_deref(), Some("id1"));

        let denied = DeferredToolResults::denied(None, "Nope");
        assert_eq!(denied.len(), 1);
        assert!(denied.results[0].tool_call_id.is_none());
        assert!(denied.results[0].result.is_error());

        let denied_with_id = DeferredToolResults::denied(Some("id2".to_string()), "Nope");
        assert_eq!(denied_with_id.results[0].tool_call_id.as_deref(), Some("id2"));
    }

    #[test]
    fn test_tool_denied() {
        let denied = ToolDenied::new("Not allowed");
        assert_eq!(denied.message, "Not allowed");
    }

    #[tokio::test]
    async fn test_auto_approver() {
        let approver = AutoApprover;
        let call = DeferredToolCall::new("test", serde_json::json!({}));
        let decision = approver.approve(&call).await;
        assert!(decision.is_approved());
    }

    #[tokio::test]
    async fn test_auto_denier() {
        let denier = AutoDenier::new("Denied");
        let call = DeferredToolCall::new("test", serde_json::json!({}));
        let decision = denier.approve(&call).await;
        assert!(decision.is_denied());
        assert_eq!(decision.denial_message(), Some("Denied"));
    }

    #[test]
    fn test_serde_roundtrip() {
        let call =
            DeferredToolCall::new("test", serde_json::json!({"x": 1})).with_tool_call_id("id");
        let json = serde_json::to_string(&call).unwrap();
        let parsed: DeferredToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(call.tool_name, parsed.tool_name);
        assert_eq!(call.args, parsed.args);
        assert_eq!(call.tool_call_id, parsed.tool_call_id);
    }

    #[test]
    fn test_serde_omits_missing_tool_call_id() {
        let call = DeferredToolCall::new("test", serde_json::json!({}));
        let value = serde_json::to_value(&call).unwrap();
        // `skip_serializing_if = "Option::is_none"` drops the key entirely.
        assert!(value.get("tool_call_id").is_none());

        let parsed: DeferredToolCall =
            serde_json::from_value(serde_json::json!({"tool_name": "t", "args": null})).unwrap();
        assert!(parsed.tool_call_id.is_none());
    }
}