1use serde::{Deserialize, Serialize};
7
8use crate::ToolReturn;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct DeferredToolCall {
16 pub tool_name: String,
18 pub args: serde_json::Value,
20 #[serde(skip_serializing_if = "Option::is_none")]
22 pub tool_call_id: Option<String>,
23}
24
25impl DeferredToolCall {
26 #[must_use]
28 pub fn new(tool_name: impl Into<String>, args: serde_json::Value) -> Self {
29 Self {
30 tool_name: tool_name.into(),
31 args,
32 tool_call_id: None,
33 }
34 }
35
36 #[must_use]
38 pub fn with_tool_call_id(mut self, id: impl Into<String>) -> Self {
39 self.tool_call_id = Some(id.into());
40 self
41 }
42
43 #[must_use]
45 pub fn approve(&self) -> DeferredToolDecision {
46 DeferredToolDecision::Approved
47 }
48
49 #[must_use]
51 pub fn deny(&self, message: impl Into<String>) -> DeferredToolDecision {
52 DeferredToolDecision::Denied(message.into())
53 }
54
55 #[must_use]
57 pub fn with_result(&self, result: ToolReturn) -> DeferredToolDecision {
58 DeferredToolDecision::CustomResult(result)
59 }
60}
61
62#[derive(Debug, Clone, Default, Serialize, Deserialize)]
64pub struct DeferredToolRequests {
65 pub calls: Vec<DeferredToolCall>,
67}
68
69impl DeferredToolRequests {
70 #[must_use]
72 pub fn new() -> Self {
73 Self { calls: Vec::new() }
74 }
75
76 pub fn add(&mut self, call: DeferredToolCall) {
78 self.calls.push(call);
79 }
80
81 #[must_use]
83 pub fn is_empty(&self) -> bool {
84 self.calls.is_empty()
85 }
86
87 #[must_use]
89 pub fn len(&self) -> usize {
90 self.calls.len()
91 }
92
93 #[must_use]
95 pub fn get(&self, index: usize) -> Option<&DeferredToolCall> {
96 self.calls.get(index)
97 }
98
99 pub fn iter(&self) -> impl Iterator<Item = &DeferredToolCall> {
101 self.calls.iter()
102 }
103
104 #[must_use]
106 pub fn by_tool(&self, name: &str) -> Vec<&DeferredToolCall> {
107 self.calls.iter().filter(|c| c.tool_name == name).collect()
108 }
109
110 pub fn clear(&mut self) {
112 self.calls.clear();
113 }
114
115 #[must_use]
117 pub fn approve_all(&self) -> DeferredToolDecisions {
118 DeferredToolDecisions {
119 decisions: self
120 .calls
121 .iter()
122 .map(|_| DeferredToolDecision::Approved)
123 .collect(),
124 }
125 }
126
127 #[must_use]
129 pub fn deny_all(&self, message: impl Into<String>) -> DeferredToolDecisions {
130 let msg = message.into();
131 DeferredToolDecisions {
132 decisions: self
133 .calls
134 .iter()
135 .map(|_| DeferredToolDecision::Denied(msg.clone()))
136 .collect(),
137 }
138 }
139}
140
141impl FromIterator<DeferredToolCall> for DeferredToolRequests {
142 fn from_iter<T: IntoIterator<Item = DeferredToolCall>>(iter: T) -> Self {
143 Self {
144 calls: iter.into_iter().collect(),
145 }
146 }
147}
148
149#[derive(Debug, Clone)]
151pub enum DeferredToolDecision {
152 Approved,
154 Denied(String),
156 CustomResult(ToolReturn),
158}
159
160impl DeferredToolDecision {
161 #[must_use]
163 pub fn is_approved(&self) -> bool {
164 matches!(self, Self::Approved)
165 }
166
167 #[must_use]
169 pub fn is_denied(&self) -> bool {
170 matches!(self, Self::Denied(_))
171 }
172
173 #[must_use]
175 pub fn is_custom(&self) -> bool {
176 matches!(self, Self::CustomResult(_))
177 }
178
179 #[must_use]
181 pub fn denial_message(&self) -> Option<&str> {
182 match self {
183 Self::Denied(msg) => Some(msg),
184 _ => None,
185 }
186 }
187}
188
189#[derive(Debug, Clone, Default)]
191pub struct DeferredToolDecisions {
192 pub decisions: Vec<DeferredToolDecision>,
194}
195
196impl DeferredToolDecisions {
197 #[must_use]
199 pub fn new() -> Self {
200 Self {
201 decisions: Vec::new(),
202 }
203 }
204
205 pub fn add(&mut self, decision: DeferredToolDecision) {
207 self.decisions.push(decision);
208 }
209
210 #[must_use]
212 pub fn is_empty(&self) -> bool {
213 self.decisions.is_empty()
214 }
215
216 #[must_use]
218 pub fn len(&self) -> usize {
219 self.decisions.len()
220 }
221
222 #[must_use]
224 pub fn all_approved(&self) -> bool {
225 self.decisions.iter().all(|d| d.is_approved())
226 }
227
228 #[must_use]
230 pub fn any_denied(&self) -> bool {
231 self.decisions.iter().any(|d| d.is_denied())
232 }
233}
234
235impl FromIterator<DeferredToolDecision> for DeferredToolDecisions {
236 fn from_iter<T: IntoIterator<Item = DeferredToolDecision>>(iter: T) -> Self {
237 Self {
238 decisions: iter.into_iter().collect(),
239 }
240 }
241}
242
243#[derive(Debug, Clone)]
245pub struct DeferredToolResult {
246 pub tool_call_id: Option<String>,
248 pub result: ToolReturn,
250}
251
252impl DeferredToolResult {
253 #[must_use]
255 pub fn new(result: ToolReturn) -> Self {
256 Self {
257 tool_call_id: None,
258 result,
259 }
260 }
261
262 #[must_use]
264 pub fn with_tool_call_id(mut self, id: impl Into<String>) -> Self {
265 self.tool_call_id = Some(id.into());
266 self
267 }
268
269 #[must_use]
271 pub fn approved() -> Self {
272 Self::new(ToolReturn::text("Tool execution approved"))
273 }
274
275 #[must_use]
277 pub fn denied(message: impl Into<String>) -> Self {
278 Self::new(ToolReturn::error(message))
279 }
280}
281
282#[derive(Debug, Clone, Default)]
284pub struct DeferredToolResults {
285 pub results: Vec<DeferredToolResult>,
287}
288
289impl DeferredToolResults {
290 #[must_use]
292 pub fn new() -> Self {
293 Self {
294 results: Vec::new(),
295 }
296 }
297
298 pub fn add(&mut self, result: DeferredToolResult) {
300 self.results.push(result);
301 }
302
303 #[must_use]
305 pub fn approved(id: Option<String>) -> Self {
306 let mut result = DeferredToolResult::approved();
307 if let Some(id) = id {
308 result = result.with_tool_call_id(id);
309 }
310 Self {
311 results: vec![result],
312 }
313 }
314
315 #[must_use]
317 pub fn denied(id: Option<String>, message: impl Into<String>) -> Self {
318 let mut result = DeferredToolResult::denied(message);
319 if let Some(id) = id {
320 result = result.with_tool_call_id(id);
321 }
322 Self {
323 results: vec![result],
324 }
325 }
326
327 #[must_use]
329 pub fn is_empty(&self) -> bool {
330 self.results.is_empty()
331 }
332
333 #[must_use]
335 pub fn len(&self) -> usize {
336 self.results.len()
337 }
338}
339
340impl FromIterator<DeferredToolResult> for DeferredToolResults {
341 fn from_iter<T: IntoIterator<Item = DeferredToolResult>>(iter: T) -> Self {
342 Self {
343 results: iter.into_iter().collect(),
344 }
345 }
346}
347
/// Unit marker type for an approved tool call; the counterpart of [`ToolDenied`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ToolApproved;
351
/// Outcome type for a denied tool call, carrying the denial message.
///
/// Derives `PartialEq`/`Eq` (like [`ToolApproved`]) so denials can be
/// compared directly rather than field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolDenied {
    /// Message supplied when the denial was created.
    pub message: String,
}

impl ToolDenied {
    /// Creates a denial carrying `message`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }
}
368
/// Produces a [`DeferredToolDecision`] for each [`DeferredToolCall`] it is shown.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because
// callers cannot require the returned future to be `Send`.
// NOTE(review): allowed deliberately here — confirm no caller needs to spawn
// these futures on a multi-threaded executor.
#[allow(async_fn_in_trait)]
pub trait ToolApprover {
    /// Returns the decision for `call`.
    async fn approve(&self, call: &DeferredToolCall) -> DeferredToolDecision;
}
377
378#[derive(Debug, Clone, Copy, Default)]
380pub struct AutoApprover;
381
382impl ToolApprover for AutoApprover {
383 async fn approve(&self, _call: &DeferredToolCall) -> DeferredToolDecision {
384 DeferredToolDecision::Approved
385 }
386}
387
/// A [`ToolApprover`] that denies every call with one fixed message.
#[derive(Clone, Debug)]
pub struct AutoDenier {
    // Message copied into every denial this approver produces.
    message: String,
}

impl AutoDenier {
    /// Creates a denier whose decisions all carry `message`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}
403
404impl ToolApprover for AutoDenier {
405 async fn approve(&self, _call: &DeferredToolCall) -> DeferredToolDecision {
406 DeferredToolDecision::Denied(self.message.clone())
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deferred_tool_call() {
        let call = DeferredToolCall::new("my_tool", serde_json::json!({"x": 1}))
            .with_tool_call_id("call_123");

        assert_eq!(call.tool_name, "my_tool");
        assert_eq!(call.tool_call_id, Some("call_123".to_string()));
    }

    #[test]
    fn test_deferred_tool_requests() {
        let mut requests = DeferredToolRequests::new();
        assert!(requests.is_empty());

        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool2", serde_json::json!({})));

        assert_eq!(requests.len(), 2);
        assert!(!requests.is_empty());
    }

    #[test]
    fn test_by_tool() {
        let mut requests = DeferredToolRequests::new();
        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool2", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));

        let tool1_calls = requests.by_tool("tool1");
        assert_eq!(tool1_calls.len(), 2);
    }

    #[test]
    fn test_approve_all() {
        let mut requests = DeferredToolRequests::new();
        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));
        requests.add(DeferredToolCall::new("tool2", serde_json::json!({})));

        let decisions = requests.approve_all();
        assert_eq!(decisions.len(), 2);
        assert!(decisions.all_approved());
    }

    #[test]
    fn test_deny_all() {
        let mut requests = DeferredToolRequests::new();
        requests.add(DeferredToolCall::new("tool1", serde_json::json!({})));

        let decisions = requests.deny_all("Not allowed");
        assert!(decisions.any_denied());
    }

    #[test]
    fn test_deferred_tool_decision() {
        let approved = DeferredToolDecision::Approved;
        assert!(approved.is_approved());
        assert!(!approved.is_denied());

        let denied = DeferredToolDecision::Denied("No".into());
        assert!(denied.is_denied());
        assert_eq!(denied.denial_message(), Some("No"));

        let custom = DeferredToolDecision::CustomResult(ToolReturn::text("custom"));
        assert!(custom.is_custom());
    }

    #[test]
    fn test_deferred_tool_result() {
        let result = DeferredToolResult::approved().with_tool_call_id("id1");
        assert_eq!(result.tool_call_id, Some("id1".to_string()));

        let denied = DeferredToolResult::denied("Not allowed");
        assert!(denied.result.is_error());
    }

    #[test]
    fn test_deferred_tool_results() {
        let results = DeferredToolResults::approved(Some("id1".to_string()));
        assert_eq!(results.len(), 1);

        let denied = DeferredToolResults::denied(None, "Nope");
        assert_eq!(denied.len(), 1);
    }

    #[test]
    fn test_tool_denied() {
        let denied = ToolDenied::new("Not allowed");
        assert_eq!(denied.message, "Not allowed");
    }

    #[tokio::test]
    async fn test_auto_approver() {
        let approver = AutoApprover;
        let call = DeferredToolCall::new("test", serde_json::json!({}));
        let decision = approver.approve(&call).await;
        assert!(decision.is_approved());
    }

    #[tokio::test]
    async fn test_auto_denier() {
        let denier = AutoDenier::new("Denied");
        let call = DeferredToolCall::new("test", serde_json::json!({}));
        let decision = denier.approve(&call).await;
        assert!(decision.is_denied());
    }

    #[test]
    fn test_serde_roundtrip() {
        let call =
            DeferredToolCall::new("test", serde_json::json!({"x": 1})).with_tool_call_id("id");
        let json = serde_json::to_string(&call).unwrap();
        let parsed: DeferredToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(call.tool_name, parsed.tool_name);
        assert_eq!(call.tool_call_id, parsed.tool_call_id);
    }

    // Covers the decision helpers on `DeferredToolCall`, which were untested.
    #[test]
    fn test_call_decision_helpers() {
        let call = DeferredToolCall::new("my_tool", serde_json::json!({}));

        assert!(call.approve().is_approved());

        let denied = call.deny("Not allowed");
        assert!(denied.is_denied());
        assert_eq!(denied.denial_message(), Some("Not allowed"));

        let custom = call.with_result(ToolReturn::text("custom"));
        assert!(custom.is_custom());
        assert_eq!(custom.denial_message(), None);
    }

    #[test]
    fn test_requests_accessors_and_from_iter() {
        let mut requests: DeferredToolRequests = vec![
            DeferredToolCall::new("tool1", serde_json::json!({})),
            DeferredToolCall::new("tool2", serde_json::json!({})),
        ]
        .into_iter()
        .collect();

        assert_eq!(requests.len(), 2);
        assert_eq!(requests.get(1).map(|c| c.tool_name.as_str()), Some("tool2"));
        assert!(requests.get(2).is_none());

        let names: Vec<&str> = requests.iter().map(|c| c.tool_name.as_str()).collect();
        assert_eq!(names, ["tool1", "tool2"]);

        requests.clear();
        assert!(requests.is_empty());
    }

    #[test]
    fn test_bulk_decisions_on_empty_requests() {
        let requests = DeferredToolRequests::new();

        let approved = requests.approve_all();
        assert!(approved.is_empty());
        // `all` over an empty list is vacuously true.
        assert!(approved.all_approved());

        let denied = requests.deny_all("Not allowed");
        assert!(denied.is_empty());
        assert!(!denied.any_denied());
    }

    #[test]
    fn test_deny_all_repeats_message() {
        let requests: DeferredToolRequests = (0..3)
            .map(|i| DeferredToolCall::new(format!("tool{i}"), serde_json::json!({})))
            .collect();

        let decisions = requests.deny_all("Not allowed");
        assert_eq!(decisions.len(), 3);
        assert!(decisions
            .decisions
            .iter()
            .all(|d| d.denial_message() == Some("Not allowed")));
    }

    #[test]
    fn test_decisions_mixed() {
        let mut decisions = DeferredToolDecisions::new();
        assert!(decisions.is_empty());
        decisions.add(DeferredToolDecision::Approved);
        assert!(decisions.all_approved());

        decisions.add(DeferredToolDecision::Denied("No".into()));
        assert_eq!(decisions.len(), 2);
        assert!(!decisions.all_approved());
        assert!(decisions.any_denied());

        let collected: DeferredToolDecisions = [
            DeferredToolDecision::Approved,
            DeferredToolDecision::CustomResult(ToolReturn::text("custom")),
        ]
        .into_iter()
        .collect();
        assert_eq!(collected.len(), 2);
        assert!(!collected.all_approved());
        assert!(!collected.any_denied());
    }

    #[test]
    fn test_results_add_and_from_iter() {
        let mut results = DeferredToolResults::new();
        assert!(results.is_empty());
        results.add(DeferredToolResult::approved());
        assert_eq!(results.len(), 1);

        let collected: DeferredToolResults = [
            DeferredToolResult::approved(),
            DeferredToolResult::denied("Not allowed"),
        ]
        .into_iter()
        .collect();
        assert_eq!(collected.len(), 2);

        let tagged = DeferredToolResults::denied(Some("id9".to_string()), "Nope");
        assert_eq!(tagged.results[0].tool_call_id.as_deref(), Some("id9"));
        assert!(tagged.results[0].result.is_error());

        let untagged = DeferredToolResults::approved(None);
        assert_eq!(untagged.results[0].tool_call_id, None);
    }

    // `tool_call_id` is `skip_serializing_if = "Option::is_none"`: absent on
    // output, and a missing key deserializes back to `None`.
    #[test]
    fn test_serde_omits_missing_tool_call_id() {
        let call = DeferredToolCall::new("test", serde_json::json!({"x": 1}));
        let value = serde_json::to_value(&call).unwrap();
        assert!(value.get("tool_call_id").is_none());
        assert_eq!(value["tool_name"], "test");

        let parsed: DeferredToolCall =
            serde_json::from_str(r#"{"tool_name":"test","args":{"x":1}}"#).unwrap();
        assert_eq!(parsed.tool_call_id, None);
        assert_eq!(parsed.args, serde_json::json!({"x": 1}));
    }
}