1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
/// A JSON Schema document, kept as an untyped [`serde_json::Value`].
pub type JsonSchema = serde_json::Value;
8
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11#[serde(rename_all = "camelCase")]
12pub struct ToolSpec {
13 pub name: String,
14 pub description: String,
15 pub input_schema: InputSchema,
16
17 #[serde(skip_serializing_if = "Option::is_none")]
18 pub output_schema: Option<JsonSchema>,
19}
20
21impl ToolSpec {
22 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
23 Self {
24 name: name.into(),
25 description: description.into(),
26 input_schema: InputSchema::default(),
27 output_schema: None,
28 }
29 }
30
31 pub fn with_input_schema(mut self, schema: JsonSchema) -> Self {
32 self.input_schema = InputSchema { json: schema };
33 self
34 }
35
36 pub fn with_output_schema(mut self, schema: JsonSchema) -> Self {
37 self.output_schema = Some(schema);
38 self
39 }
40}
41
42#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
44pub struct InputSchema {
45 pub json: JsonSchema,
46}
47
48#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
50#[serde(rename_all = "camelCase")]
51pub struct Tool {
52 pub tool_spec: ToolSpec,
53}
54
55#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
57#[serde(rename_all = "camelCase")]
58pub struct ToolUse {
59 pub name: String,
60 pub tool_use_id: String,
61 pub input: serde_json::Value,
62}
63
64impl ToolUse {
65 pub fn new(name: impl Into<String>, tool_use_id: impl Into<String>, input: serde_json::Value) -> Self {
66 Self { name: name.into(), tool_use_id: tool_use_id.into(), input }
67 }
68
69 pub fn get_param<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
70 self.input.get(key).and_then(|v| T::deserialize(v).ok())
71 }
72}
73
74#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
76#[serde(rename_all = "camelCase")]
77pub struct ToolResultContent {
78 #[serde(skip_serializing_if = "Option::is_none")]
79 pub text: Option<String>,
80
81 #[serde(skip_serializing_if = "Option::is_none")]
82 pub json: Option<serde_json::Value>,
83
84 #[serde(skip_serializing_if = "Option::is_none")]
85 pub image: Option<ImageResultContent>,
86
87 #[serde(skip_serializing_if = "Option::is_none")]
88 pub document: Option<DocumentResultContent>,
89}
90
91impl ToolResultContent {
92 pub fn text(text: impl Into<String>) -> Self {
93 Self { text: Some(text.into()), ..Default::default() }
94 }
95
96 pub fn json(value: serde_json::Value) -> Self {
97 Self { json: Some(value), ..Default::default() }
98 }
99}
100
101#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
103#[serde(rename_all = "camelCase")]
104pub struct ImageResultContent {
105 pub format: String,
106 pub data: String,
107}
108
109#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
111#[serde(rename_all = "camelCase")]
112pub struct DocumentResultContent {
113 pub format: String,
114 pub name: String,
115 pub data: String,
116}
117
118#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
120#[serde(rename_all = "lowercase")]
121pub enum ToolResultStatus {
122 Success,
123 Error,
124}
125
126impl ToolResultStatus {
127 pub fn as_str(&self) -> &'static str {
129 match self {
130 ToolResultStatus::Success => "success",
131 ToolResultStatus::Error => "error",
132 }
133 }
134}
135
136impl std::fmt::Display for ToolResultStatus {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 write!(f, "{}", self.as_str())
139 }
140}
141
142#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
144#[serde(rename_all = "camelCase")]
145pub struct ToolResult {
146 pub tool_use_id: String,
147 pub status: ToolResultStatus,
148 pub content: Vec<ToolResultContent>,
149}
150
151impl ToolResult {
152 pub fn success(tool_use_id: impl Into<String>, text: impl Into<String>) -> Self {
153 Self {
154 tool_use_id: tool_use_id.into(),
155 status: ToolResultStatus::Success,
156 content: vec![ToolResultContent::text(text)],
157 }
158 }
159
160 pub fn success_json(tool_use_id: impl Into<String>, json: serde_json::Value) -> Self {
161 Self {
162 tool_use_id: tool_use_id.into(),
163 status: ToolResultStatus::Success,
164 content: vec![ToolResultContent::json(json)],
165 }
166 }
167
168 pub fn error(tool_use_id: impl Into<String>, error_message: impl Into<String>) -> Self {
169 Self {
170 tool_use_id: tool_use_id.into(),
171 status: ToolResultStatus::Error,
172 content: vec![ToolResultContent::text(error_message)],
173 }
174 }
175
176 pub fn is_success(&self) -> bool { self.status == ToolResultStatus::Success }
177 pub fn is_error(&self) -> bool { self.status == ToolResultStatus::Error }
178}
179
180#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
182pub struct ToolChoiceAuto {}
183
184#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
186pub struct ToolChoiceAny {}
187
188#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
190pub struct ToolChoiceTool {
191 pub name: String,
192}
193
194#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
196#[serde(rename_all = "lowercase")]
197pub enum ToolChoice {
198 Auto(ToolChoiceAuto),
199 Any(ToolChoiceAny),
200 Tool(ToolChoiceTool),
201}
202
203impl Default for ToolChoice {
204 fn default() -> Self { Self::Auto(ToolChoiceAuto {}) }
205}
206
207impl ToolChoice {
208 pub fn auto() -> Self { Self::Auto(ToolChoiceAuto {}) }
209 pub fn any() -> Self { Self::Any(ToolChoiceAny {}) }
210 pub fn tool(name: impl Into<String>) -> Self { Self::Tool(ToolChoiceTool { name: name.into() }) }
211}
212
213#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
215#[serde(rename_all = "camelCase")]
216pub struct ToolConfig {
217 pub tools: Vec<Tool>,
218
219 #[serde(skip_serializing_if = "Option::is_none")]
220 pub tool_choice: Option<ToolChoice>,
221}
222
223#[derive(Debug, Clone)]
225pub struct ToolContext {
226 pub tool_use: ToolUse,
227 pub invocation_state: HashMap<String, serde_json::Value>,
228}
229
230impl ToolContext {
231 pub fn new(tool_use: ToolUse) -> Self {
232 Self { tool_use, invocation_state: HashMap::new() }
233 }
234
235 pub fn with_state(tool_use: ToolUse, state: HashMap<String, serde_json::Value>) -> Self {
236 Self { tool_use, invocation_state: state }
237 }
238
239 pub fn get_state<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
240 self.invocation_state.get(key).and_then(|v| T::deserialize(v).ok())
241 }
242
243 pub fn interrupt_id(&self, name: &str) -> String {
244 format!(
245 "v1:tool_call:{}:{}",
246 self.tool_use.tool_use_id,
247 uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_OID, name.as_bytes())
248 )
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    #[test]
    fn test_tool_spec_creation() {
        let spec = ToolSpec::new("get_weather", "Get weather for a location");
        assert_eq!(spec.name, "get_weather");
        assert_eq!(spec.description, "Get weather for a location");
    }

    /// Pins the camelCase wire format and checks that it round-trips.
    #[test]
    fn test_tool_spec_serialization_round_trip() {
        let spec = ToolSpec::new("calc", "Adds numbers")
            .with_input_schema(json!({ "type": "object" }))
            .with_output_schema(json!({ "type": "number" }));
        let value = serde_json::to_value(&spec).unwrap();
        assert_eq!(
            value,
            json!({
                "name": "calc",
                "description": "Adds numbers",
                "inputSchema": { "json": { "type": "object" } },
                "outputSchema": { "type": "number" }
            })
        );
        let back: ToolSpec = serde_json::from_value(value).unwrap();
        assert_eq!(back, spec);
    }

    #[test]
    fn test_tool_spec_omits_missing_output_schema() {
        let spec =
            ToolSpec::new("calc", "Adds numbers").with_input_schema(json!({ "type": "object" }));
        let value = serde_json::to_value(&spec).unwrap();
        assert!(value.get("outputSchema").is_none());
    }

    /// `get_param` returns `None` for both missing keys and type mismatches.
    #[test]
    fn test_tool_use_deserialize_and_get_param() {
        let tool_use: ToolUse = serde_json::from_value(json!({
            "name": "get_weather",
            "toolUseId": "abc",
            "input": { "city": "Paris", "days": 3 }
        }))
        .unwrap();
        assert_eq!(tool_use.tool_use_id, "abc");
        assert_eq!(tool_use.get_param::<String>("city").as_deref(), Some("Paris"));
        assert_eq!(tool_use.get_param::<u32>("days"), Some(3));
        assert_eq!(tool_use.get_param::<u32>("city"), None);
        assert_eq!(tool_use.get_param::<String>("missing"), None);
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("123", "Weather is sunny");
        assert!(result.is_success());
        assert!(!result.is_error());
    }

    #[test]
    fn test_tool_result_error() {
        let result = ToolResult::error("123", "Failed to fetch weather");
        assert!(result.is_error());
        assert!(!result.is_success());
    }

    #[test]
    fn test_tool_result_serialization() {
        let failed = ToolResult::error("id-1", "boom");
        assert_eq!(
            serde_json::to_value(&failed).unwrap(),
            json!({ "toolUseId": "id-1", "status": "error", "content": [{ "text": "boom" }] })
        );

        let structured = ToolResult::success_json("id-2", json!({ "ok": true }));
        assert_eq!(
            serde_json::to_value(&structured).unwrap(),
            json!({
                "toolUseId": "id-2",
                "status": "success",
                "content": [{ "json": { "ok": true } }]
            })
        );
    }

    #[test]
    fn test_tool_result_status_display() {
        assert_eq!(ToolResultStatus::Success.to_string(), "success");
        assert_eq!(ToolResultStatus::Error.to_string(), "error");
    }

    #[test]
    fn test_tool_choice_variants() {
        let auto = ToolChoice::auto();
        assert!(matches!(auto, ToolChoice::Auto(_)));

        let any = ToolChoice::any();
        assert!(matches!(any, ToolChoice::Any(_)));

        let specific = ToolChoice::tool("my_tool");
        assert!(matches!(specific, ToolChoice::Tool(t) if t.name == "my_tool"));
    }

    #[test]
    fn test_tool_result_content_serialization() {
        let content = ToolResultContent::text("hello");
        let json = serde_json::to_string(&content).unwrap();
        assert_eq!(json, r#"{"text":"hello"}"#);
    }

    #[test]
    fn test_tool_choice_serialization() {
        let auto = ToolChoice::auto();
        let json = serde_json::to_string(&auto).unwrap();
        assert_eq!(json, r#"{"auto":{}}"#);

        let any = ToolChoice::any();
        let json = serde_json::to_string(&any).unwrap();
        assert_eq!(json, r#"{"any":{}}"#);

        let tool = ToolChoice::tool("my_tool");
        let json = serde_json::to_string(&tool).unwrap();
        assert_eq!(json, r#"{"tool":{"name":"my_tool"}}"#);
    }

    #[test]
    fn test_tool_choice_round_trip() {
        let choices = vec![ToolChoice::auto(), ToolChoice::any(), ToolChoice::tool("my_tool")];
        for choice in choices {
            let text = serde_json::to_string(&choice).unwrap();
            let back: ToolChoice = serde_json::from_str(&text).unwrap();
            assert_eq!(back, choice);
        }
    }

    /// Covers typed state lookup and the deterministic, name-dependent interrupt id.
    #[test]
    fn test_tool_context_state_and_interrupt_id() {
        let tool_use = ToolUse::new("t", "use-1", json!({}));
        let mut state = HashMap::new();
        state.insert("count".to_string(), json!(2));
        let ctx = ToolContext::with_state(tool_use.clone(), state);

        assert_eq!(ctx.get_state::<i64>("count"), Some(2));
        assert_eq!(ctx.get_state::<String>("count"), None);
        assert_eq!(ToolContext::new(tool_use).get_state::<i64>("count"), None);

        let id = ctx.interrupt_id("confirm");
        assert!(id.starts_with("v1:tool_call:use-1:"));
        assert_eq!(id, ctx.interrupt_id("confirm"));
        assert_ne!(id, ctx.interrupt_id("other"));
    }
}