1use std::sync::Arc;
4
5use futures::stream::{self, StreamExt};
6
7use super::{AgentTool, InvocationState, ToolEvent, ToolContext};
8use crate::types::tools::{ToolResult, ToolResultContent, ToolResultStatus, ToolUse};
9
/// How a [`ToolExecutor`] schedules a batch of tool invocations.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExecutionMode {
    /// Run tools one after another, in input order.
    Sequential,
    /// Run tools concurrently with at most `max_parallel` in flight at once
    /// (`None` means no explicit cap).
    Concurrent { max_parallel: Option<usize> },
}

impl Default for ExecutionMode {
    /// Unbounded concurrency: `Concurrent { max_parallel: None }`.
    fn default() -> Self {
        ExecutionMode::Concurrent { max_parallel: None }
    }
}
20
21pub struct ToolExecutor {
23 mode: ExecutionMode,
24}
25
26impl Default for ToolExecutor {
27 fn default() -> Self { Self::new() }
28}
29
30impl ToolExecutor {
31 pub fn new() -> Self { Self { mode: ExecutionMode::default() } }
32 pub fn sequential() -> Self { Self { mode: ExecutionMode::Sequential } }
33 pub fn concurrent(max_parallel: Option<usize>) -> Self { Self { mode: ExecutionMode::Concurrent { max_parallel } } }
34
35 pub async fn execute_all(
37 &self,
38 tools: &[(Arc<dyn AgentTool>, ToolUse)],
39 invocation_state: &InvocationState,
40 ) -> Vec<(String, Vec<ToolEvent>)> {
41 match self.mode {
42 ExecutionMode::Sequential => self.execute_sequential(tools, invocation_state).await,
43 ExecutionMode::Concurrent { max_parallel } => self.execute_concurrent(tools, invocation_state, max_parallel).await,
44 }
45 }
46
47 async fn execute_sequential(
48 &self,
49 tools: &[(Arc<dyn AgentTool>, ToolUse)],
50 invocation_state: &InvocationState,
51 ) -> Vec<(String, Vec<ToolEvent>)> {
52 let mut results = Vec::with_capacity(tools.len());
53 for (tool, tool_use) in tools {
54 let events = Self::execute_single(tool.clone(), tool_use, invocation_state).await;
55 results.push((tool_use.tool_use_id.clone(), events));
56 }
57 results
58 }
59
60 async fn execute_concurrent(
61 &self,
62 tools: &[(Arc<dyn AgentTool>, ToolUse)],
63 invocation_state: &InvocationState,
64 max_parallel: Option<usize>,
65 ) -> Vec<(String, Vec<ToolEvent>)> {
66 let limit = max_parallel.unwrap_or(tools.len());
67
68 let futures = tools.iter().map(|(tool, tool_use)| {
69 let tool = tool.clone();
70 let tool_use = tool_use.clone();
71 let state = invocation_state.clone();
72 async move {
73 let events = Self::execute_single(tool, &tool_use, &state).await;
74 (tool_use.tool_use_id, events)
75 }
76 });
77
78 stream::iter(futures).buffer_unordered(limit).collect().await
79 }
80
81 async fn execute_single(
82 tool: Arc<dyn AgentTool>,
83 tool_use: &ToolUse,
84 invocation_state: &InvocationState,
85 ) -> Vec<ToolEvent> {
86 let context = ToolContext::with_state(invocation_state.clone());
87 let result = match tool.invoke(tool_use.input.clone(), &context).await {
88 Ok(r) => ToolResult {
89 tool_use_id: tool_use.tool_use_id.clone(),
90 status: r.status,
91 content: r.content,
92 },
93 Err(e) => ToolResult {
94 tool_use_id: tool_use.tool_use_id.clone(),
95 status: ToolResultStatus::Error,
96 content: vec![ToolResultContent::text(e)],
97 },
98 };
99 vec![ToolEvent::Result(result)]
100 }
101
102 pub async fn execute_one(
104 &self,
105 tool: Arc<dyn AgentTool>,
106 tool_use: &ToolUse,
107 invocation_state: &InvocationState,
108 ) -> Vec<ToolEvent> {
109 Self::execute_single(tool, tool_use, invocation_state).await
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use crate::tools::ToolResult2;
    use crate::types::tools::ToolSpec;

    /// Test double that sleeps for `delay_ms` and then reports success.
    struct SlowTool {
        name: String,
        delay_ms: u64,
    }

    #[async_trait]
    impl AgentTool for SlowTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A slow tool"
        }

        fn tool_spec(&self) -> ToolSpec {
            ToolSpec::new(&self.name, "A slow tool")
        }

        async fn invoke(
            &self,
            _input: serde_json::Value,
            _context: &ToolContext,
        ) -> std::result::Result<ToolResult2, String> {
            let pause = tokio::time::Duration::from_millis(self.delay_ms);
            tokio::time::sleep(pause).await;
            Ok(ToolResult2::success(format!("done after {}ms", self.delay_ms)))
        }
    }

    /// Wraps a `SlowTool` as the trait object the executor expects.
    fn slow_tool(name: &str, delay_ms: u64) -> Arc<dyn AgentTool> {
        Arc::new(SlowTool { name: name.to_string(), delay_ms })
    }

    #[tokio::test]
    async fn test_sequential_execution() {
        let tools = vec![
            (slow_tool("tool1", 10), ToolUse::new("tool1", "1", serde_json::json!({}))),
            (slow_tool("tool2", 10), ToolUse::new("tool2", "2", serde_json::json!({}))),
        ];

        let outcome = ToolExecutor::sequential()
            .execute_all(&tools, &InvocationState::new())
            .await;

        assert_eq!(outcome.len(), 2);
    }

    #[tokio::test]
    async fn test_concurrent_execution() {
        let tools = vec![
            (slow_tool("tool1", 50), ToolUse::new("tool1", "1", serde_json::json!({}))),
            (slow_tool("tool2", 50), ToolUse::new("tool2", "2", serde_json::json!({}))),
        ];
        let runner = ToolExecutor::concurrent(None);
        let state = InvocationState::new();

        let started = std::time::Instant::now();
        let outcome = runner.execute_all(&tools, &state).await;
        let took = started.elapsed();

        assert_eq!(outcome.len(), 2);
        // Two 50ms tools finishing in under 100ms proves they overlapped.
        assert!(took < std::time::Duration::from_millis(100));
    }
}