strands_agents/types/
tools.rs

//! Tool-related types for agent tools.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6/// JSON Schema type alias.
7pub type JsonSchema = serde_json::Value;
8
9/// Specification for an agent tool.
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11#[serde(rename_all = "camelCase")]
12pub struct ToolSpec {
13    pub name: String,
14    pub description: String,
15    pub input_schema: InputSchema,
16
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub output_schema: Option<JsonSchema>,
19}
20
21impl ToolSpec {
22    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
23        Self {
24            name: name.into(),
25            description: description.into(),
26            input_schema: InputSchema::default(),
27            output_schema: None,
28        }
29    }
30
31    pub fn with_input_schema(mut self, schema: JsonSchema) -> Self {
32        self.input_schema = InputSchema { json: schema };
33        self
34    }
35
36    pub fn with_output_schema(mut self, schema: JsonSchema) -> Self {
37        self.output_schema = Some(schema);
38        self
39    }
40}
41
42/// Input schema for a tool.
43#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
44pub struct InputSchema {
45    pub json: JsonSchema,
46}
47
48/// A tool configuration containing its specification.
49#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
50#[serde(rename_all = "camelCase")]
51pub struct Tool {
52    pub tool_spec: ToolSpec,
53}
54
55/// A request to use a tool.
56#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
57#[serde(rename_all = "camelCase")]
58pub struct ToolUse {
59    pub name: String,
60    pub tool_use_id: String,
61    pub input: serde_json::Value,
62}
63
64impl ToolUse {
65    pub fn new(name: impl Into<String>, tool_use_id: impl Into<String>, input: serde_json::Value) -> Self {
66        Self { name: name.into(), tool_use_id: tool_use_id.into(), input }
67    }
68
69    pub fn get_param<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
70        self.input.get(key).and_then(|v| T::deserialize(v).ok())
71    }
72}
73
74/// Content returned from a tool execution.
75#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
76#[serde(rename_all = "camelCase")]
77pub struct ToolResultContent {
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub text: Option<String>,
80
81    #[serde(skip_serializing_if = "Option::is_none")]
82    pub json: Option<serde_json::Value>,
83
84    #[serde(skip_serializing_if = "Option::is_none")]
85    pub image: Option<ImageResultContent>,
86
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub document: Option<DocumentResultContent>,
89}
90
91impl ToolResultContent {
92    pub fn text(text: impl Into<String>) -> Self {
93        Self { text: Some(text.into()), ..Default::default() }
94    }
95
96    pub fn json(value: serde_json::Value) -> Self {
97        Self { json: Some(value), ..Default::default() }
98    }
99}
100
101/// Image content in a tool result.
102#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
103#[serde(rename_all = "camelCase")]
104pub struct ImageResultContent {
105    pub format: String,
106    pub data: String,
107}
108
109/// Document content in a tool result.
110#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
111#[serde(rename_all = "camelCase")]
112pub struct DocumentResultContent {
113    pub format: String,
114    pub name: String,
115    pub data: String,
116}
117
118/// Status of a tool execution.
119#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
120#[serde(rename_all = "lowercase")]
121pub enum ToolResultStatus {
122    Success,
123    Error,
124}
125
126impl ToolResultStatus {
127    /// Returns the string representation of the status.
128    pub fn as_str(&self) -> &'static str {
129        match self {
130            ToolResultStatus::Success => "success",
131            ToolResultStatus::Error => "error",
132        }
133    }
134}
135
136impl std::fmt::Display for ToolResultStatus {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        write!(f, "{}", self.as_str())
139    }
140}
141
142/// Result from a tool execution.
143#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
144#[serde(rename_all = "camelCase")]
145pub struct ToolResult {
146    pub tool_use_id: String,
147    pub status: ToolResultStatus,
148    pub content: Vec<ToolResultContent>,
149}
150
151impl ToolResult {
152    pub fn success(tool_use_id: impl Into<String>, text: impl Into<String>) -> Self {
153        Self {
154            tool_use_id: tool_use_id.into(),
155            status: ToolResultStatus::Success,
156            content: vec![ToolResultContent::text(text)],
157        }
158    }
159
160    pub fn success_json(tool_use_id: impl Into<String>, json: serde_json::Value) -> Self {
161        Self {
162            tool_use_id: tool_use_id.into(),
163            status: ToolResultStatus::Success,
164            content: vec![ToolResultContent::json(json)],
165        }
166    }
167
168    pub fn error(tool_use_id: impl Into<String>, error_message: impl Into<String>) -> Self {
169        Self {
170            tool_use_id: tool_use_id.into(),
171            status: ToolResultStatus::Error,
172            content: vec![ToolResultContent::text(error_message)],
173        }
174    }
175
176    pub fn is_success(&self) -> bool { self.status == ToolResultStatus::Success }
177    pub fn is_error(&self) -> bool { self.status == ToolResultStatus::Error }
178}
179
180/// Auto tool choice - model decides whether to use tools.
181#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
182pub struct ToolChoiceAuto {}
183
184/// Any tool choice - model must use at least one tool.
185#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
186pub struct ToolChoiceAny {}
187
188/// Specific tool choice - model must use the named tool.
189#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
190pub struct ToolChoiceTool {
191    pub name: String,
192}
193
194/// Configuration for how the model should use tools.
195#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
196#[serde(rename_all = "lowercase")]
197pub enum ToolChoice {
198    Auto(ToolChoiceAuto),
199    Any(ToolChoiceAny),
200    Tool(ToolChoiceTool),
201}
202
203impl Default for ToolChoice {
204    fn default() -> Self { Self::Auto(ToolChoiceAuto {}) }
205}
206
207impl ToolChoice {
208    pub fn auto() -> Self { Self::Auto(ToolChoiceAuto {}) }
209    pub fn any() -> Self { Self::Any(ToolChoiceAny {}) }
210    pub fn tool(name: impl Into<String>) -> Self { Self::Tool(ToolChoiceTool { name: name.into() }) }
211}
212
213/// Tool configuration for a model request.
214#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
215#[serde(rename_all = "camelCase")]
216pub struct ToolConfig {
217    pub tools: Vec<Tool>,
218
219    #[serde(skip_serializing_if = "Option::is_none")]
220    pub tool_choice: Option<ToolChoice>,
221}
222
223/// Context provided to a tool during execution.
224#[derive(Debug, Clone)]
225pub struct ToolContext {
226    pub tool_use: ToolUse,
227    pub invocation_state: HashMap<String, serde_json::Value>,
228}
229
230impl ToolContext {
231    pub fn new(tool_use: ToolUse) -> Self {
232        Self { tool_use, invocation_state: HashMap::new() }
233    }
234
235    pub fn with_state(tool_use: ToolUse, state: HashMap<String, serde_json::Value>) -> Self {
236        Self { tool_use, invocation_state: state }
237    }
238
239    pub fn get_state<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
240        self.invocation_state.get(key).and_then(|v| T::deserialize(v).ok())
241    }
242
243    pub fn interrupt_id(&self, name: &str) -> String {
244        format!(
245            "v1:tool_call:{}:{}",
246            self.tool_use.tool_use_id,
247            uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_OID, name.as_bytes())
248        )
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_tool_spec_creation() {
        let spec = ToolSpec::new("get_weather", "Get weather for a location");
        assert_eq!(spec.name, "get_weather");
        assert_eq!(spec.description, "Get weather for a location");
    }

    #[test]
    fn test_tool_spec_serialization_uses_camel_case() {
        // An explicit input schema keeps this independent of InputSchema's default.
        let spec = ToolSpec::new("get_weather", "Get weather")
            .with_input_schema(json!({"type": "object"}));
        let value = serde_json::to_value(&spec).unwrap();
        assert_eq!(
            value,
            json!({
                "name": "get_weather",
                "description": "Get weather",
                "inputSchema": {"json": {"type": "object"}}
            })
        );
    }

    #[test]
    fn test_tool_use_get_param() {
        let tool_use = ToolUse::new("get_weather", "id-1", json!({"city": "Paris", "days": 3}));
        assert_eq!(tool_use.get_param::<String>("city").as_deref(), Some("Paris"));
        assert_eq!(tool_use.get_param::<u32>("days"), Some(3));
        // Present but not deserializable as the requested type.
        assert_eq!(tool_use.get_param::<u32>("city"), None);
        // Absent key.
        assert_eq!(tool_use.get_param::<String>("country"), None);
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("123", "Weather is sunny");
        assert!(result.is_success());
        assert!(!result.is_error());
    }

    #[test]
    fn test_tool_result_error() {
        let result = ToolResult::error("123", "Failed to fetch weather");
        assert!(result.is_error());
        assert!(!result.is_success());
    }

    #[test]
    fn test_tool_result_serialization_round_trip() {
        let result = ToolResult::success("id-1", "ok");
        let value = serde_json::to_value(&result).unwrap();
        assert_eq!(
            value,
            json!({"toolUseId": "id-1", "status": "success", "content": [{"text": "ok"}]})
        );
        let round_trip: ToolResult = serde_json::from_value(value).unwrap();
        assert_eq!(round_trip, result);
    }

    #[test]
    fn test_tool_result_status_display_and_serde() {
        assert_eq!(ToolResultStatus::Success.to_string(), "success");
        assert_eq!(ToolResultStatus::Error.to_string(), "error");
        assert_eq!(serde_json::to_string(&ToolResultStatus::Error).unwrap(), r#""error""#);
    }

    #[test]
    fn test_tool_choice_variants() {
        let auto = ToolChoice::auto();
        assert!(matches!(auto, ToolChoice::Auto(_)));

        let any = ToolChoice::any();
        assert!(matches!(any, ToolChoice::Any(_)));

        let specific = ToolChoice::tool("my_tool");
        assert!(matches!(specific, ToolChoice::Tool(t) if t.name == "my_tool"));
    }

    #[test]
    fn test_tool_result_content_serialization() {
        let content = ToolResultContent::text("hello");
        let json = serde_json::to_string(&content).unwrap();
        assert_eq!(json, r#"{"text":"hello"}"#);
    }

    #[test]
    fn test_tool_choice_serialization() {
        let auto = ToolChoice::auto();
        let json = serde_json::to_string(&auto).unwrap();
        assert_eq!(json, r#"{"auto":{}}"#);

        let any = ToolChoice::any();
        let json = serde_json::to_string(&any).unwrap();
        assert_eq!(json, r#"{"any":{}}"#);

        let tool = ToolChoice::tool("my_tool");
        let json = serde_json::to_string(&tool).unwrap();
        assert_eq!(json, r#"{"tool":{"name":"my_tool"}}"#);
    }

    #[test]
    fn test_tool_choice_deserialization() {
        let choice: ToolChoice = serde_json::from_str(r#"{"tool":{"name":"my_tool"}}"#).unwrap();
        assert_eq!(choice, ToolChoice::tool("my_tool"));

        let auto: ToolChoice = serde_json::from_str(r#"{"auto":{}}"#).unwrap();
        assert_eq!(auto, ToolChoice::auto());
    }

    #[test]
    fn test_tool_context_state_and_interrupt_id() {
        let tool_use = ToolUse::new("t", "id-1", json!({}));
        let mut state = std::collections::HashMap::new();
        state.insert("count".to_string(), json!(5));
        let ctx = ToolContext::with_state(tool_use, state);

        assert_eq!(ctx.get_state::<i64>("count"), Some(5));
        assert_eq!(ctx.get_state::<i64>("missing"), None);
        assert_eq!(ctx.get_state::<String>("count"), None);

        let id = ctx.interrupt_id("approval");
        assert!(id.starts_with("v1:tool_call:id-1:"));
        // Name-based UUIDs are deterministic and differ for different names.
        assert_eq!(id, ctx.interrupt_id("approval"));
        assert_ne!(id, ctx.interrupt_id("other"));
    }
}