traitclaw_core/traits/
tool.rs1use async_trait::async_trait;
7use schemars::JsonSchema;
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use serde_json::Value;
10
11use crate::Result;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ToolSchema {
16 pub name: String,
18 pub description: String,
20 pub parameters: Value,
22}
23
24#[async_trait]
63pub trait Tool: Send + Sync + 'static {
64 type Input: DeserializeOwned + JsonSchema + Send;
66 type Output: Serialize + Send;
68
69 fn name(&self) -> &str;
71
72 fn description(&self) -> &str;
74
75 fn schema(&self) -> ToolSchema {
77 let schema = schemars::schema_for!(Self::Input);
78 ToolSchema {
79 name: self.name().to_string(),
80 description: self.description().to_string(),
81 parameters: serde_json::to_value(schema).unwrap_or_default(),
82 }
83 }
84
85 async fn execute(&self, input: Self::Input) -> Result<Self::Output>;
87}
88
89#[async_trait]
93pub trait ErasedTool: Send + Sync + 'static {
94 fn name(&self) -> &str;
96
97 fn description(&self) -> &str;
99
100 fn schema(&self) -> ToolSchema;
102
103 async fn execute_json(&self, input: Value) -> Result<Value>;
105}
106
107#[async_trait]
108impl<T: Tool> ErasedTool for T {
109 fn name(&self) -> &str {
110 Tool::name(self)
111 }
112
113 fn description(&self) -> &str {
114 Tool::description(self)
115 }
116
117 fn schema(&self) -> ToolSchema {
118 Tool::schema(self)
119 }
120
121 async fn execute_json(&self, input: Value) -> Result<Value> {
122 let typed_input: T::Input = serde_json::from_value(input).map_err(|e| {
123 crate::Error::tool_execution(self.name(), format!("Invalid input: {e}"))
124 })?;
125
126 let output = self.execute(typed_input).await?;
127
128 serde_json::to_value(output).map_err(|e| {
129 crate::Error::tool_execution(self.name(), format!("Failed to serialize output: {e}"))
130 })
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Arguments accepted by `AddTool`.
    #[derive(Deserialize, JsonSchema)]
    struct AddInput {
        a: i64,
        b: i64,
    }

    /// Result produced by `AddTool`.
    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct AddOutput {
        sum: i64,
    }

    /// Minimal tool used to exercise both the typed and the erased APIs.
    struct AddTool;

    #[async_trait]
    // The string literals could be returned as `&'static str`, but these
    // methods mirror the `&str` signatures declared by `Tool`.
    #[allow(clippy::unnecessary_literal_bound)]
    impl Tool for AddTool {
        type Input = AddInput;
        type Output = AddOutput;

        fn name(&self) -> &str {
            "add"
        }

        fn description(&self) -> &str {
            "Add two numbers"
        }

        async fn execute(&self, input: Self::Input) -> Result<Self::Output> {
            let AddInput { a, b } = input;
            Ok(AddOutput { sum: a + b })
        }
    }

    /// The typed API runs the tool directly on a Rust value.
    #[tokio::test]
    async fn test_tool_execute_typed() {
        let adder = AddTool;
        let out = adder.execute(AddInput { a: 3, b: 4 }).await.unwrap();
        assert_eq!(out.sum, 7);
    }

    /// Name and description are reported verbatim.
    #[test]
    fn test_tool_name_and_description() {
        let adder = AddTool;
        assert_eq!(Tool::name(&adder), "add");
        assert_eq!(Tool::description(&adder), "Add two numbers");
    }

    /// The generated schema lists every field of the input type.
    #[test]
    fn test_schema_generation() {
        let generated = Tool::schema(&AddTool);

        assert_eq!(generated.name, "add");
        assert_eq!(generated.description, "Add two numbers");

        let properties = generated
            .parameters
            .get("properties")
            .expect("schema should have properties");
        for field in ["a", "b"] {
            assert!(
                properties.get(field).is_some(),
                "schema missing '{}' property",
                field
            );
        }
    }

    /// `ToolSchema` serializes to a JSON object with the expected keys.
    #[test]
    fn test_tool_schema_serializes_to_json() {
        let encoded = serde_json::to_value(Tool::schema(&AddTool)).unwrap();

        assert_eq!(encoded["name"], "add");
        assert_eq!(encoded["description"], "Add two numbers");
        assert!(encoded["parameters"].is_object());
    }

    /// Typed tools can be stored as erased trait objects.
    #[test]
    fn test_erased_tool_in_vec() {
        let registry: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];

        assert_eq!(registry.len(), 1);
        let first = &registry[0];
        assert_eq!(first.name(), "add");
        assert_eq!(first.description(), "Add two numbers");
    }

    /// JSON input is deserialized, executed, and re-serialized correctly.
    #[tokio::test]
    async fn test_erased_tool_json_round_trip() {
        let erased: Arc<dyn ErasedTool> = Arc::new(AddTool);

        let reply = erased
            .execute_json(serde_json::json!({ "a": 10, "b": 20 }))
            .await
            .unwrap();

        let decoded: AddOutput = serde_json::from_value(reply).unwrap();
        assert_eq!(decoded, AddOutput { sum: 30 });
    }

    /// Malformed JSON input yields an error naming the tool and the cause.
    #[tokio::test]
    async fn test_erased_tool_invalid_input_returns_error() {
        let erased: Arc<dyn ErasedTool> = Arc::new(AddTool);

        let outcome = erased
            .execute_json(serde_json::json!({ "x": "not a number" }))
            .await;

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("add"), "error should mention tool name");
        assert!(
            message.contains("Invalid input"),
            "error should say invalid input"
        );
    }

    /// The erased schema is exactly the typed schema.
    #[test]
    fn test_erased_tool_schema_matches_tool_schema() {
        let direct = Tool::schema(&AddTool);
        let erased = ErasedTool::schema(&AddTool);

        assert_eq!(
            (direct.name, direct.description, direct.parameters),
            (erased.name, erased.description, erased.parameters)
        );
    }
}