Skip to main content

traitclaw_core/traits/
tool.rs

1//! Tool trait for agent capabilities.
2//!
3//! The [`Tool`] trait defines a capability that an agent can use.
4//! Tools have typed inputs and outputs with automatic JSON Schema generation.
5
6use async_trait::async_trait;
7use schemars::JsonSchema;
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use serde_json::Value;
10
11use crate::Result;
12
13/// JSON Schema representation for a tool's parameters.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ToolSchema {
16    /// Tool name.
17    pub name: String,
18    /// Tool description.
19    pub description: String,
20    /// JSON Schema for the tool's input parameters.
21    pub parameters: Value,
22}
23
24/// Trait for defining agent tools with typed inputs and outputs.
25///
26/// Tools are the primary way agents interact with the outside world.
27/// Each tool has a typed `Input` (auto-generates JSON Schema) and `Output`.
28///
29/// # Example
30///
31/// ```rust,no_run
32/// use async_trait::async_trait;
33/// use traitclaw_core::prelude::*;
34/// use schemars::JsonSchema;
35/// use serde::{Deserialize, Serialize};
36///
37/// #[derive(Deserialize, JsonSchema)]
38/// struct SearchInput {
39///     query: String,
40/// }
41///
42/// #[derive(Serialize)]
43/// struct SearchOutput {
44///     results: Vec<String>,
45/// }
46///
47/// struct WebSearch;
48///
49/// #[async_trait]
50/// impl Tool for WebSearch {
51///     type Input = SearchInput;
52///     type Output = SearchOutput;
53///
54///     fn name(&self) -> &str { "web_search" }
55///     fn description(&self) -> &str { "Search the web" }
56///
57///     async fn execute(&self, input: Self::Input) -> traitclaw_core::Result<Self::Output> {
58///         Ok(SearchOutput { results: vec![format!("Result for: {}", input.query)] })
59///     }
60/// }
61/// ```
62#[async_trait]
63pub trait Tool: Send + Sync + 'static {
64    /// Input type — must be deserializable from JSON and have a JSON Schema.
65    type Input: DeserializeOwned + JsonSchema + Send;
66    /// Output type — must be serializable to JSON.
67    type Output: Serialize + Send;
68
69    /// The unique name of this tool.
70    fn name(&self) -> &str;
71
72    /// A description of what this tool does (sent to the LLM).
73    fn description(&self) -> &str;
74
75    /// Generate the JSON Schema for this tool's parameters.
76    fn schema(&self) -> ToolSchema {
77        let schema = schemars::schema_for!(Self::Input);
78        ToolSchema {
79            name: self.name().to_string(),
80            description: self.description().to_string(),
81            parameters: serde_json::to_value(schema).unwrap_or_default(),
82        }
83    }
84
85    /// Execute this tool with the given input.
86    async fn execute(&self, input: Self::Input) -> Result<Self::Output>;
87}
88
89/// Type-erased tool wrapper for dynamic dispatch.
90///
91/// Allows storing heterogeneous tools in `Vec<Arc<dyn ErasedTool>>`.
92#[async_trait]
93pub trait ErasedTool: Send + Sync + 'static {
94    /// The unique name of this tool.
95    fn name(&self) -> &str;
96
97    /// A description of what this tool does.
98    fn description(&self) -> &str;
99
100    /// Get the JSON Schema for this tool.
101    fn schema(&self) -> ToolSchema;
102
103    /// Execute this tool with JSON input, returning JSON output.
104    async fn execute_json(&self, input: Value) -> Result<Value>;
105}
106
107#[async_trait]
108impl<T: Tool> ErasedTool for T {
109    fn name(&self) -> &str {
110        Tool::name(self)
111    }
112
113    fn description(&self) -> &str {
114        Tool::description(self)
115    }
116
117    fn schema(&self) -> ToolSchema {
118        Tool::schema(self)
119    }
120
121    async fn execute_json(&self, input: Value) -> Result<Value> {
122        let typed_input: T::Input = serde_json::from_value(input).map_err(|e| {
123            crate::Error::tool_execution(self.name(), format!("Invalid input: {e}"))
124        })?;
125
126        let output = self.execute(typed_input).await?;
127
128        serde_json::to_value(output).map_err(|e| {
129            crate::Error::tool_execution(self.name(), format!("Failed to serialize output: {e}"))
130        })
131    }
132}
133
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    // ── Test fixtures ──────────────────────────────────────────────
    // Plain `//` comments on purpose: schemars turns `///` doc comments into schema
    // `description`s, which would change the generated schema under test.

    // Arguments for `AddTool`.
    #[derive(Deserialize, JsonSchema)]
    struct AddInput {
        a: i64,
        b: i64,
    }

    // Result of `AddTool`.
    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct AddOutput {
        sum: i64,
    }

    // Minimal tool that adds two integers.
    struct AddTool;

    #[async_trait]
    // `name`/`description` return string literals, so clippy suggests narrowing the
    // return type to `&'static str`; keep the signatures identical to the `Tool` trait.
    #[allow(clippy::unnecessary_literal_bound)]
    impl Tool for AddTool {
        type Input = AddInput;
        type Output = AddOutput;

        fn name(&self) -> &str {
            "add"
        }
        fn description(&self) -> &str {
            "Add two numbers"
        }

        async fn execute(&self, input: Self::Input) -> Result<Self::Output> {
            Ok(AddOutput {
                sum: input.a + input.b,
            })
        }
    }

    // ── AC 1+2: Manual Tool implementation with typed I/O ──────

    #[tokio::test]
    async fn test_tool_execute_typed() {
        let tool = AddTool;
        let result = tool.execute(AddInput { a: 3, b: 4 }).await.unwrap();
        assert_eq!(result.sum, 7);
    }

    #[test]
    fn test_tool_name_and_description() {
        let tool = AddTool;
        assert_eq!(Tool::name(&tool), "add");
        assert_eq!(Tool::description(&tool), "Add two numbers");
    }

    // ── AC 1: Schema generation from schemars ──────────────────

    #[test]
    fn test_schema_generation() {
        let tool = AddTool;
        let schema = Tool::schema(&tool);

        assert_eq!(schema.name, "add");
        assert_eq!(schema.description, "Add two numbers");

        // Schema must contain properties for a and b
        let params = &schema.parameters;
        let props = params
            .get("properties")
            .expect("schema should have properties");
        assert!(props.get("a").is_some(), "schema missing 'a' property");
        assert!(props.get("b").is_some(), "schema missing 'b' property");
    }

    #[test]
    fn test_tool_schema_serializes_to_json() {
        let tool = AddTool;
        let schema = Tool::schema(&tool);

        // ToolSchema must serialize (for OpenAI tools parameter format)
        let json = serde_json::to_value(&schema).unwrap();
        assert_eq!(json["name"], "add");
        assert_eq!(json["description"], "Add two numbers");
        assert!(json["parameters"].is_object());
    }

    // ── AC 3: ErasedTool enables heterogeneous storage ─────────

    #[test]
    fn test_erased_tool_in_vec() {
        let tools: Vec<Arc<dyn ErasedTool>> = vec![Arc::new(AddTool)];

        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name(), "add");
        assert_eq!(tools[0].description(), "Add two numbers");
    }

    #[test]
    fn test_erased_tool_schema_via_trait_object() {
        let tool: Arc<dyn ErasedTool> = Arc::new(AddTool);
        let schema = tool.schema();

        assert_eq!(schema.name, "add");
        assert_eq!(schema.description, "Add two numbers");
        assert!(
            schema.parameters.get("properties").is_some(),
            "schema obtained through dyn ErasedTool should have properties"
        );
    }

    // ── AC 4: ErasedTool JSON round-trip ───────────────────────

    #[tokio::test]
    async fn test_erased_tool_json_round_trip() {
        let tool: Arc<dyn ErasedTool> = Arc::new(AddTool);

        let input = serde_json::json!({"a": 10, "b": 20});
        let output = tool.execute_json(input).await.unwrap();

        let result: AddOutput = serde_json::from_value(output).unwrap();
        assert_eq!(result.sum, 30);
    }

    #[tokio::test]
    async fn test_erased_tool_invalid_input_returns_error() {
        let tool: Arc<dyn ErasedTool> = Arc::new(AddTool);

        let bad_input = serde_json::json!({"x": "not a number"});
        let result = tool.execute_json(bad_input).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("add"),
            "error should mention tool name"
        );
        assert!(
            err.to_string().contains("Invalid input"),
            "error should say invalid input"
        );
    }

    #[tokio::test]
    async fn test_erased_tool_wrong_field_type_returns_error() {
        let tool: Arc<dyn ErasedTool> = Arc::new(AddTool);

        // All required fields are present, but `a` has the wrong JSON type.
        let bad_input = serde_json::json!({"a": "ten", "b": 2});
        let err = tool.execute_json(bad_input).await.unwrap_err();
        let message = err.to_string();

        assert!(
            message.contains("add"),
            "error should mention tool name: {message}"
        );
        assert!(
            message.contains("Invalid input"),
            "error should say invalid input: {message}"
        );
    }

    #[test]
    fn test_erased_tool_schema_matches_tool_schema() {
        let tool = AddTool;
        let direct_schema = Tool::schema(&tool);
        let erased_schema = ErasedTool::schema(&tool);

        assert_eq!(direct_schema.name, erased_schema.name);
        assert_eq!(direct_schema.description, erased_schema.description);
        assert_eq!(direct_schema.parameters, erased_schema.parameters);
    }
}
269        assert_eq!(direct_schema.name, erased_schema.name);
270        assert_eq!(direct_schema.description, erased_schema.description);
271        assert_eq!(direct_schema.parameters, erased_schema.parameters);
272    }
273}