strands_agents/tools/
executor.rs

1//! Tool execution with sequential and concurrent modes.
2
3use std::sync::Arc;
4
5use futures::stream::{self, StreamExt};
6
7use super::{AgentTool, InvocationState, ToolEvent, ToolContext};
8use crate::types::tools::{ToolResult, ToolResultContent, ToolResultStatus, ToolUse};
9
/// Scheduling strategy used when a batch of tools is executed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionMode {
    /// Run the tools one after another, awaiting each before starting the next.
    Sequential,
    /// Run the tools concurrently.
    ///
    /// `max_parallel` bounds how many tools may be in flight at once; `None` places no bound.
    Concurrent { max_parallel: Option<usize> },
}

impl Default for ExecutionMode {
    /// Concurrent execution with no cap on parallelism.
    fn default() -> Self {
        let unbounded = None;
        ExecutionMode::Concurrent { max_parallel: unbounded }
    }
}
20
21/// Executor for running tools.
22pub struct ToolExecutor {
23    mode: ExecutionMode,
24}
25
26impl Default for ToolExecutor {
27    fn default() -> Self { Self::new() }
28}
29
30impl ToolExecutor {
31    pub fn new() -> Self { Self { mode: ExecutionMode::default() } }
32    pub fn sequential() -> Self { Self { mode: ExecutionMode::Sequential } }
33    pub fn concurrent(max_parallel: Option<usize>) -> Self { Self { mode: ExecutionMode::Concurrent { max_parallel } } }
34
35    /// Executes all tools and returns their results.
36    pub async fn execute_all(
37        &self,
38        tools: &[(Arc<dyn AgentTool>, ToolUse)],
39        invocation_state: &InvocationState,
40    ) -> Vec<(String, Vec<ToolEvent>)> {
41        match self.mode {
42            ExecutionMode::Sequential => self.execute_sequential(tools, invocation_state).await,
43            ExecutionMode::Concurrent { max_parallel } => self.execute_concurrent(tools, invocation_state, max_parallel).await,
44        }
45    }
46
47    async fn execute_sequential(
48        &self,
49        tools: &[(Arc<dyn AgentTool>, ToolUse)],
50        invocation_state: &InvocationState,
51    ) -> Vec<(String, Vec<ToolEvent>)> {
52        let mut results = Vec::with_capacity(tools.len());
53        for (tool, tool_use) in tools {
54            let events = Self::execute_single(tool.clone(), tool_use, invocation_state).await;
55            results.push((tool_use.tool_use_id.clone(), events));
56        }
57        results
58    }
59
60    async fn execute_concurrent(
61        &self,
62        tools: &[(Arc<dyn AgentTool>, ToolUse)],
63        invocation_state: &InvocationState,
64        max_parallel: Option<usize>,
65    ) -> Vec<(String, Vec<ToolEvent>)> {
66        let limit = max_parallel.unwrap_or(tools.len());
67
68        let futures = tools.iter().map(|(tool, tool_use)| {
69            let tool = tool.clone();
70            let tool_use = tool_use.clone();
71            let state = invocation_state.clone();
72            async move {
73                let events = Self::execute_single(tool, &tool_use, &state).await;
74                (tool_use.tool_use_id, events)
75            }
76        });
77
78        stream::iter(futures).buffer_unordered(limit).collect().await
79    }
80
81    async fn execute_single(
82        tool: Arc<dyn AgentTool>,
83        tool_use: &ToolUse,
84        invocation_state: &InvocationState,
85    ) -> Vec<ToolEvent> {
86        let context = ToolContext::with_state(invocation_state.clone());
87        let result = match tool.invoke(tool_use.input.clone(), &context).await {
88            Ok(r) => ToolResult {
89                tool_use_id: tool_use.tool_use_id.clone(),
90                status: r.status,
91                content: r.content,
92            },
93            Err(e) => ToolResult {
94                tool_use_id: tool_use.tool_use_id.clone(),
95                status: ToolResultStatus::Error,
96                content: vec![ToolResultContent::text(e)],
97            },
98        };
99        vec![ToolEvent::Result(result)]
100    }
101
102    /// Executes a single tool and returns its events.
103    pub async fn execute_one(
104        &self,
105        tool: Arc<dyn AgentTool>,
106        tool_use: &ToolUse,
107        invocation_state: &InvocationState,
108    ) -> Vec<ToolEvent> {
109        Self::execute_single(tool, tool_use, invocation_state).await
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use crate::tools::ToolResult2;
    use crate::types::tools::ToolSpec;

    /// Test double that sleeps for `delay_ms` and then reports success.
    struct SlowTool {
        name: String,
        delay_ms: u64,
    }

    #[async_trait]
    impl AgentTool for SlowTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A slow tool"
        }

        fn tool_spec(&self) -> ToolSpec {
            ToolSpec::new(&self.name, "A slow tool")
        }

        async fn invoke(
            &self,
            _input: serde_json::Value,
            _context: &ToolContext,
        ) -> std::result::Result<ToolResult2, String> {
            let pause = tokio::time::Duration::from_millis(self.delay_ms);
            tokio::time::sleep(pause).await;
            Ok(ToolResult2::success(format!("done after {}ms", self.delay_ms)))
        }
    }

    /// Builds the batch used by both tests: "tool1" (id "1") and "tool2" (id "2"), each of
    /// which sleeps for `delay_ms`.
    fn two_slow_tools(delay_ms: u64) -> Vec<(Arc<dyn AgentTool>, ToolUse)> {
        [("tool1", "1"), ("tool2", "2")]
            .iter()
            .map(|&(name, id)| {
                let tool: Arc<dyn AgentTool> =
                    Arc::new(SlowTool { name: name.to_string(), delay_ms });
                (tool, ToolUse::new(name, id, serde_json::json!({})))
            })
            .collect()
    }

    #[tokio::test]
    async fn test_sequential_execution() {
        let batch = two_slow_tools(10);
        let state = InvocationState::new();

        let results = ToolExecutor::sequential().execute_all(&batch, &state).await;

        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_concurrent_execution() {
        let batch = two_slow_tools(50);
        let state = InvocationState::new();
        let executor = ToolExecutor::concurrent(None);

        let started = std::time::Instant::now();
        let results = executor.execute_all(&batch, &state).await;
        let elapsed_ms = started.elapsed().as_millis();

        assert_eq!(results.len(), 2);
        // Running the two 50 ms tools back-to-back would take at least 100 ms.
        assert!(elapsed_ms < 100);
    }
}