strands_agents/tools/mod.rs

1//! Tool definitions and execution.
2
3pub mod executor;
4pub mod helpers;
5pub mod loader;
6pub mod mcp;
7pub mod mcp_instrumentation;
8pub mod registry;
9pub mod structured_output;
10pub mod validator;
11pub mod watcher;
12
13use std::pin::Pin;
14
15use futures::Stream;
16
17use async_trait::async_trait;
18
19use crate::types::tools::{ToolResult, ToolResultContent, ToolResultStatus, ToolSpec, ToolUse};
20
21/// A stream of tool execution events.
22pub type ToolEventStream = Pin<Box<dyn Stream<Item = ToolEvent> + Send>>;
23
24/// Generator type for tool execution streams.
25pub type ToolGenerator = ToolEventStream;
26
27/// Events emitted during tool execution.
28#[derive(Debug, Clone)]
29pub enum ToolEvent {
30    /// Progress update during execution.
31    Progress { message: String },
32    /// Streaming data from the tool.
33    Stream(serde_json::Value),
34    /// Final result of tool execution.
35    Result(ToolResult),
36    /// Interrupt request from the tool.
37    Interrupt { id: String, data: serde_json::Value },
38}
39
40impl ToolEvent {
41    pub fn progress(message: impl Into<String>) -> Self {
42        Self::Progress { message: message.into() }
43    }
44
45    pub fn stream(data: serde_json::Value) -> Self { Self::Stream(data) }
46    pub fn result(result: ToolResult) -> Self { Self::Result(result) }
47    pub fn is_result(&self) -> bool { matches!(self, Self::Result(_)) }
48
49    pub fn as_result(&self) -> Option<&ToolResult> {
50        match self {
51            Self::Result(r) => Some(r),
52            _ => None,
53        }
54    }
55}
56
57/// State passed through tool invocations.
58#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
59pub struct InvocationState {
60    pub data: std::collections::HashMap<String, serde_json::Value>,
61    #[serde(default)]
62    pub stop_event_loop: bool,
63}
64
65impl InvocationState {
66    pub fn new() -> Self { Self::default() }
67
68    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
69        self.data.get(key).and_then(|v| T::deserialize(v).ok())
70    }
71
72    pub fn set(&mut self, key: impl Into<String>, value: impl serde::Serialize) {
73        if let Ok(v) = serde_json::to_value(value) {
74            self.data.insert(key.into(), v);
75        }
76    }
77}
78
79/// Context provided to tool execution.
80#[derive(Debug, Clone, Default)]
81pub struct ToolContext {
82    pub invocation_state: InvocationState,
83    pub interrupt_id: Option<uuid::Uuid>,
84}
85
86impl ToolContext {
87    pub fn new() -> Self { Self::default() }
88
89    pub fn with_state(state: InvocationState) -> Self {
90        Self { invocation_state: state, interrupt_id: None }
91    }
92}
93
94/// Result returned from async tool invocation.
95#[derive(Debug, Clone)]
96pub struct ToolResult2 {
97    pub status: ToolResultStatus,
98    pub content: Vec<ToolResultContent>,
99}
100
101impl ToolResult2 {
102    pub fn success(content: impl Into<String>) -> Self {
103        Self {
104            status: ToolResultStatus::Success,
105            content: vec![ToolResultContent::text(content.into())],
106        }
107    }
108
109    pub fn success_json(value: serde_json::Value) -> Self {
110        Self {
111            status: ToolResultStatus::Success,
112            content: vec![ToolResultContent::json(value)],
113        }
114    }
115
116    pub fn error(message: impl Into<String>) -> Self {
117        Self {
118            status: ToolResultStatus::Error,
119            content: vec![ToolResultContent::text(message.into())],
120        }
121    }
122}
123
124/// Trait for implementing agent tools.
125#[async_trait]
126pub trait AgentTool: Send + Sync {
127    /// Returns the unique name of the tool.
128    fn name(&self) -> &str;
129
130    /// Returns the tool description.
131    fn description(&self) -> &str;
132
133    /// Returns the tool specification.
134    fn tool_spec(&self) -> ToolSpec;
135
136    /// Invokes the tool asynchronously.
137    async fn invoke(
138        &self,
139        input: serde_json::Value,
140        context: &ToolContext,
141    ) -> std::result::Result<ToolResult2, String>;
142
143    /// Legacy name accessor.
144    fn tool_name(&self) -> &str { self.name() }
145
146    /// Returns the type of the tool (e.g., "function", "python").
147    fn tool_type(&self) -> &str { "function" }
148
149    /// Whether the tool supports hot reloading.
150    fn supports_hot_reload(&self) -> bool { false }
151
152    /// Whether this is a dynamically loaded tool.
153    fn is_dynamic(&self) -> bool { false }
154
155    /// Returns display properties for the tool.
156    fn get_display_properties(&self) -> std::collections::HashMap<String, String> {
157        let mut props = std::collections::HashMap::new();
158        props.insert("Name".to_string(), self.name().to_string());
159        props.insert("Type".to_string(), self.tool_type().to_string());
160        props
161    }
162}
163
164/// Executes an agent tool and returns a stream of events.
165pub fn tool_to_stream(
166    tool: std::sync::Arc<dyn AgentTool>,
167    tool_use: ToolUse,
168    invocation_state: InvocationState,
169) -> ToolGenerator {
170    let input = tool_use.input.clone();
171    let tool_use_id = tool_use.tool_use_id.clone();
172    let context = ToolContext::with_state(invocation_state);
173
174    Box::pin(async_stream::stream! {
175        let result = match tool.invoke(input, &context).await {
176            Ok(r) => ToolResult {
177                tool_use_id,
178                status: r.status,
179                content: r.content,
180            },
181            Err(e) => ToolResult {
182                tool_use_id,
183                status: ToolResultStatus::Error,
184                content: vec![ToolResultContent::text(e)],
185            },
186        };
187        yield ToolEvent::Result(result);
188    })
189}
190
191/// Trait for dynamically loaded tools.
192pub trait DynamicAgentTool: AgentTool {
193    /// Marks the tool as dynamic.
194    fn mark_dynamic(&mut self);
195}
196
197/// Executes a tool and returns its event stream.
198pub fn execute_tool_stream(
199    tool: std::sync::Arc<dyn AgentTool>,
200    tool_use: ToolUse,
201    invocation_state: InvocationState,
202) -> ToolGenerator {
203    tool_to_stream(tool, tool_use, invocation_state)
204}
205
206pub use loader::{ReloadCallback, ToolLoader, ToolLoaderConfig, ToolWatcher};
207pub use mcp::{
208    ConnectionState, MCPClient, MCPImageContent, MCPImageSource, MCPResultContent,
209    MCPServerConfig, MCPToolResult, MCPToolSpec, MCPTransport, ToolFilters, ToolProvider,
210};
211pub use registry::{ToolInput, ToolRegistry};
212pub use structured_output::{
213    flatten_schema, get_required_fields, process_schema_for_optional_fields, schema_to_tool_spec,
214    structured_output_spec, validate_against_schema, StructuredOutputContext, StructuredOutputResult,
215    StructuredOutputTool,
216};
217pub use helpers::{
218    generate_cancelled_tool_result, generate_missing_tool_result,
219    generate_missing_tool_result_content, generate_missing_tool_results_for_message,
220    generate_timeout_tool_result, noop_tool, noop_tool_with, NoopTool,
221};
222pub use validator::{
223    is_valid_tool_name, sanitize_tool_name, validate_and_prepare_tools, validate_tool_spec,
224    validate_tool_specs, validate_tool_use, validate_tool_uses, ToolUseValidationResult,
225    MAX_TOOL_NAME_LENGTH, MIN_TOOL_NAME_LENGTH,
226};
227pub use mcp_instrumentation::{
228    create_mcp_tool_span, extract_trace_context, init_mcp_instrumentation, inject_trace_context,
229    is_instrumentation_applied, ExtractableContext, InjectableContext, ItemWithContext,
230    MCPInstrumentationConfig, InstrumentationGuard,
231};
232
233
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use futures::StreamExt;

    /// Tool that always succeeds with a fixed text result.
    struct TestTool;

    #[async_trait]
    impl AgentTool for TestTool {
        fn name(&self) -> &str { "test_tool" }
        fn description(&self) -> &str { "A test tool" }
        fn tool_spec(&self) -> ToolSpec { ToolSpec::new("test_tool", "A test tool") }

        async fn invoke(
            &self,
            _input: serde_json::Value,
            _context: &ToolContext,
        ) -> std::result::Result<ToolResult2, String> {
            Ok(ToolResult2::success("Test result"))
        }
    }

    /// Tool that always fails; exercises the `Err` -> error-result conversion.
    struct FailingTool;

    #[async_trait]
    impl AgentTool for FailingTool {
        fn name(&self) -> &str { "failing_tool" }
        fn description(&self) -> &str { "A tool that always fails" }
        fn tool_spec(&self) -> ToolSpec {
            ToolSpec::new("failing_tool", "A tool that always fails")
        }

        async fn invoke(
            &self,
            _input: serde_json::Value,
            _context: &ToolContext,
        ) -> std::result::Result<ToolResult2, String> {
            Err("boom".to_string())
        }
    }

    /// Drains `stream` and returns its single `Result` event.
    ///
    /// Asserts that the stream yields exactly one event, that it is a `Result`,
    /// and that the stream then ends. This keeps the tests from passing vacuously
    /// on an empty stream.
    async fn single_result(mut stream: ToolGenerator) -> ToolResult {
        let event = stream
            .next()
            .await
            .expect("stream ended without yielding a result");
        let result = match event {
            ToolEvent::Result(result) => result,
            other => panic!("expected a Result event, got {:?}", other),
        };
        assert!(
            stream.next().await.is_none(),
            "stream must end after the result event"
        );
        result
    }

    #[tokio::test]
    async fn test_tool_execution() {
        let tool: Arc<dyn AgentTool> = Arc::new(TestTool);
        let tool_use = ToolUse::new("test_tool", "123", serde_json::json!({}));
        let expected_id = tool_use.tool_use_id.clone();

        let result = single_result(tool_to_stream(tool, tool_use, InvocationState::new())).await;

        assert!(result.is_success());
        assert_eq!(result.tool_use_id, expected_id);
    }

    #[tokio::test]
    async fn test_tool_error_becomes_error_result() {
        let tool: Arc<dyn AgentTool> = Arc::new(FailingTool);
        let tool_use = ToolUse::new("failing_tool", "456", serde_json::json!({}));
        let expected_id = tool_use.tool_use_id.clone();

        let result = single_result(tool_to_stream(tool, tool_use, InvocationState::new())).await;

        assert!(!result.is_success());
        assert!(matches!(result.status, ToolResultStatus::Error));
        assert_eq!(result.tool_use_id, expected_id);
    }

    #[test]
    fn test_invocation_state_round_trip() {
        let mut state = InvocationState::new();
        state.set("count", 3u32);
        assert_eq!(state.get::<u32>("count"), Some(3));
        // Keys that are present with the wrong type, and absent keys, both read back as `None`.
        assert_eq!(state.get::<String>("count"), None);
        assert_eq!(state.get::<u32>("missing"), None);
    }
}