serdes_ai_tools/
tool.rs

1//! Core tool trait and implementations.
2//!
//! This module provides the `Tool` trait, which all tools must implement,
//! as well as the `FunctionTool` and `SyncFunctionTool` wrappers (plus the
//! `sync_tool` helper) for closure-based tools.
5
6use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use std::future::Future;
9use std::marker::PhantomData;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use crate::{
14    definition::ToolDefinition, return_types::ToolResult, schema::SchemaBuilder, RunContext,
15};
16
17/// Core trait for all tools.
18///
19/// Implement this trait to create custom tools that can be called by agents.
20/// Tools receive a context with dependencies and arguments as JSON.
21///
22/// # Type Parameters
23///
24/// - `Deps`: The type of dependencies the tool requires. Defaults to `()`.
25///
26/// # Example
27///
28/// ```ignore
29/// use async_trait::async_trait;
30/// use serdes_ai_tools::{Tool, ToolDefinition, RunContext, ToolResult, ToolReturn};
31///
32/// struct GreetTool;
33///
34/// #[async_trait]
35/// impl Tool for GreetTool {
36///     fn definition(&self) -> ToolDefinition {
37///         ToolDefinition::new("greet", "Greet someone")
38///     }
39///
40///     async fn call(&self, _ctx: &RunContext, args: serde_json::Value) -> ToolResult {
41///         let name = args["name"].as_str().unwrap_or("World");
42///         Ok(ToolReturn::text(format!("Hello, {name}!")))
43///     }
44/// }
45/// ```
46#[async_trait]
47pub trait Tool<Deps = ()>: Send + Sync {
48    /// Get the tool's definition.
49    ///
50    /// This provides the name, description, and parameter schema that
51    /// will be sent to the language model.
52    fn definition(&self) -> ToolDefinition;
53
54    /// Execute the tool with given arguments.
55    ///
56    /// # Arguments
57    ///
58    /// - `ctx`: The run context with dependencies and metadata
59    /// - `args`: The arguments as a JSON value
60    ///
61    /// # Returns
62    ///
63    /// A `ToolResult` containing either the return value or an error.
64    async fn call(&self, ctx: &RunContext<Deps>, args: JsonValue) -> ToolResult;
65
66    /// Maximum retries for this tool.
67    ///
68    /// Returns `None` to use the agent's default retry setting.
69    fn max_retries(&self) -> Option<u32> {
70        None
71    }
72
73    /// Prepare the tool definition at runtime.
74    ///
75    /// This allows modifying the tool definition based on the current context,
76    /// for example to add dynamic descriptions or modify parameters.
77    ///
78    /// Return `None` to indicate the tool should be hidden from this run.
79    async fn prepare(
80        &self,
81        _ctx: &RunContext<Deps>,
82        def: ToolDefinition,
83    ) -> Option<ToolDefinition> {
84        Some(def)
85    }
86
87    /// Get the tool name.
88    fn name(&self) -> String {
89        self.definition().name.clone()
90    }
91
92    /// Get the tool description.
93    fn description(&self) -> String {
94        self.definition().description.clone()
95    }
96}
97
98/// Type-erased boxed tool.
99pub type BoxedTool<Deps> = Arc<dyn Tool<Deps>>;
100
101/// Wrapper for function-based tools.
102///
103/// This allows creating tools from async closures without implementing
104/// the `Tool` trait manually.
105///
106/// # Example
107///
108/// ```ignore
109/// use serdes_ai_tools::{FunctionTool, SchemaBuilder, ToolReturn};
110///
111/// let tool = FunctionTool::new(
112///     "add",
113///     "Add two numbers",
114///     SchemaBuilder::new()
115///         .number("a", "First number", true)
116///         .number("b", "Second number", true)
117///         .build()
118///         .unwrap(),
119///     |_ctx, args| async move {
120///         let a = args["a"].as_f64().unwrap_or(0.0);
121///         let b = args["b"].as_f64().unwrap_or(0.0);
122///         Ok(ToolReturn::text(format!("{}", a + b)))
123///     },
124/// );
125/// ```
126pub struct FunctionTool<F, Deps = ()> {
127    name: String,
128    description: String,
129    parameters: JsonValue,
130    function: F,
131    max_retries: Option<u32>,
132    strict: Option<bool>,
133    _phantom: PhantomData<fn() -> Deps>,
134}
135
136impl<F, Deps> FunctionTool<F, Deps> {
137    /// Create a new function tool.
138    pub fn new(
139        name: impl Into<String>,
140        description: impl Into<String>,
141        parameters: impl Into<JsonValue>,
142        function: F,
143    ) -> Self {
144        Self {
145            name: name.into(),
146            description: description.into(),
147            parameters: parameters.into(),
148            function,
149            max_retries: None,
150            strict: None,
151            _phantom: PhantomData,
152        }
153    }
154
155    /// Set maximum retries.
156    #[must_use]
157    pub fn with_max_retries(mut self, retries: u32) -> Self {
158        self.max_retries = Some(retries);
159        self
160    }
161
162    /// Set strict mode.
163    #[must_use]
164    pub fn with_strict(mut self, strict: bool) -> Self {
165        self.strict = Some(strict);
166        self
167    }
168}
169
/// Boxed, pinned, `Send` future returned by the closures that `FunctionTool`
/// wraps.
///
/// Naming the type lets the `Tool` impl pin the closure's return type
/// exactly. `'static` is already the default bound for a boxed trait object;
/// it is spelled out to make clear that the future cannot borrow its inputs.
type PinnedFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
172
173#[async_trait]
174impl<F, Deps> Tool<Deps> for FunctionTool<F, Deps>
175where
176    F: for<'a> Fn(&'a RunContext<Deps>, JsonValue) -> PinnedFuture<ToolResult> + Send + Sync,
177    Deps: Send + Sync,
178{
179    fn definition(&self) -> ToolDefinition {
180        let mut def = ToolDefinition::new(&self.name, &self.description)
181            .with_parameters(self.parameters.clone());
182        if let Some(strict) = self.strict {
183            def = def.with_strict(strict);
184        }
185        def
186    }
187
188    async fn call(&self, ctx: &RunContext<Deps>, args: JsonValue) -> ToolResult {
189        (self.function)(ctx, args).await
190    }
191
192    fn max_retries(&self) -> Option<u32> {
193        self.max_retries
194    }
195}
196
197impl<F, Deps> std::fmt::Debug for FunctionTool<F, Deps> {
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        f.debug_struct("FunctionTool")
200            .field("name", &self.name)
201            .field("description", &self.description)
202            .field("max_retries", &self.max_retries)
203            .finish()
204    }
205}
206
207/// Wrapper for sync function tools.
208///
209/// For tools that don't need async, this provides a simpler API.
210pub struct SyncFunctionTool<F, Deps = ()> {
211    name: String,
212    description: String,
213    parameters: JsonValue,
214    function: F,
215    max_retries: Option<u32>,
216    _phantom: PhantomData<fn() -> Deps>,
217}
218
219impl<F, Deps> SyncFunctionTool<F, Deps>
220where
221    F: Fn(&RunContext<Deps>, JsonValue) -> ToolResult + Send + Sync,
222{
223    /// Create a new sync function tool.
224    pub fn new(
225        name: impl Into<String>,
226        description: impl Into<String>,
227        parameters: impl Into<JsonValue>,
228        function: F,
229    ) -> Self {
230        Self {
231            name: name.into(),
232            description: description.into(),
233            parameters: parameters.into(),
234            function,
235            max_retries: None,
236            _phantom: PhantomData,
237        }
238    }
239
240    /// Set maximum retries.
241    #[must_use]
242    pub fn with_max_retries(mut self, retries: u32) -> Self {
243        self.max_retries = Some(retries);
244        self
245    }
246}
247
248#[async_trait]
249impl<F, Deps> Tool<Deps> for SyncFunctionTool<F, Deps>
250where
251    F: Fn(&RunContext<Deps>, JsonValue) -> ToolResult + Send + Sync,
252    Deps: Send + Sync,
253{
254    fn definition(&self) -> ToolDefinition {
255        ToolDefinition::new(&self.name, &self.description).with_parameters(self.parameters.clone())
256    }
257
258    async fn call(&self, ctx: &RunContext<Deps>, args: JsonValue) -> ToolResult {
259        (self.function)(ctx, args)
260    }
261
262    fn max_retries(&self) -> Option<u32> {
263        self.max_retries
264    }
265}
266
267impl<F, Deps> std::fmt::Debug for SyncFunctionTool<F, Deps> {
268    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269        f.debug_struct("SyncFunctionTool")
270            .field("name", &self.name)
271            .field("description", &self.description)
272            .field("max_retries", &self.max_retries)
273            .finish()
274    }
275}
276
277/// Create a simple sync tool from a function.
278///
279/// # Example
280///
281/// ```ignore
282/// let tool = sync_tool(
283///     "echo",
284///     "Echo the input",
285///     |_ctx, args| {
286///         let msg = args["message"].as_str().unwrap_or("");
287///         Ok(ToolReturn::text(msg))
288///     },
289/// );
290/// ```
291pub fn sync_tool<F, Deps>(
292    name: impl Into<String>,
293    description: impl Into<String>,
294    function: F,
295) -> SyncFunctionTool<F, Deps>
296where
297    F: Fn(&RunContext<Deps>, JsonValue) -> ToolResult + Send + Sync,
298{
299    SyncFunctionTool::new(
300        name,
301        description,
302        SchemaBuilder::new()
303            .build()
304            .expect("SchemaBuilder JSON serialization failed"),
305        function,
306    )
307}
308
#[cfg(test)]
mod tests {
    //! Unit tests for the `Tool` trait defaults and the closure-based wrappers.

    use super::*;
    use crate::ToolReturn;

    #[derive(Debug, Clone, Default)]
    struct TestDeps;

    struct TestTool;

    #[async_trait]
    impl Tool<TestDeps> for TestTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("test", "Test tool").with_parameters(
                SchemaBuilder::new()
                    .integer("x", "A number", true)
                    .build()
                    .expect("SchemaBuilder JSON serialization failed"),
            )
        }

        async fn call(&self, _ctx: &RunContext<TestDeps>, args: JsonValue) -> ToolResult {
            let x = args["x"].as_i64().unwrap_or(0);
            Ok(ToolReturn::text(format!("x = {x}")))
        }

        fn max_retries(&self) -> Option<u32> {
            Some(5)
        }
    }

    #[tokio::test]
    async fn test_tool_trait() {
        let tool = TestTool;
        let ctx = RunContext::new(TestDeps, "test-model");

        assert_eq!(tool.name(), "test");
        assert_eq!(tool.description(), "Test tool");
        assert_eq!(tool.max_retries(), Some(5));

        let result = tool.call(&ctx, serde_json::json!({"x": 42})).await.unwrap();
        assert_eq!(result.as_text(), Some("x = 42"));
    }

    #[tokio::test]
    async fn test_function_tool() {
        // `Deps` cannot be inferred from `FunctionTool::new`, so it is pinned
        // here.
        let tool: FunctionTool<_, ()> = FunctionTool::new(
            "double",
            "Double a number",
            SchemaBuilder::new()
                .integer("n", "Number to double", true)
                .build()
                .expect("SchemaBuilder JSON serialization failed"),
            // Annotating the context parameter makes the closure generic over
            // its lifetime, and the explicit return type boxes the future, as
            // the `Tool` impl for `FunctionTool` requires.
            |_ctx: &RunContext<()>, args: JsonValue| -> PinnedFuture<ToolResult> {
                Box::pin(async move {
                    let n = args["n"].as_i64().unwrap_or(0);
                    Ok(ToolReturn::text(format!("{}", n * 2)))
                })
            },
        )
        .with_max_retries(3)
        .with_strict(true);

        assert_eq!(tool.name(), "double");
        assert_eq!(tool.description(), "Double a number");
        assert_eq!(tool.max_retries(), Some(3));

        let ctx = RunContext::minimal("test");
        let result = tool
            .call(&ctx, serde_json::json!({"n": 21}))
            .await
            .unwrap();
        assert_eq!(result.as_text(), Some("42"));
    }

    #[tokio::test]
    async fn test_sync_function_tool() {
        let tool = SyncFunctionTool::new(
            "add",
            "Add numbers",
            SchemaBuilder::new()
                .number("a", "First", true)
                .number("b", "Second", true)
                .build()
                .expect("SchemaBuilder JSON serialization failed"),
            |_ctx: &RunContext<()>, args: JsonValue| {
                let a = args["a"].as_f64().unwrap_or(0.0);
                let b = args["b"].as_f64().unwrap_or(0.0);
                Ok(ToolReturn::text(format!("{}", a + b)))
            },
        );

        let ctx = RunContext::minimal("test");
        let result = tool
            .call(&ctx, serde_json::json!({"a": 1.5, "b": 2.5}))
            .await
            .unwrap();
        assert_eq!(result.as_text(), Some("4"));
    }

    #[tokio::test]
    async fn test_tool_prepare() {
        let tool = TestTool;
        let ctx = RunContext::new(TestDeps, "test");
        let def = tool.definition();
        let prepared = tool.prepare(&ctx, def.clone()).await;
        assert!(prepared.is_some());
        assert_eq!(prepared.unwrap().name, def.name);
    }

    #[test]
    fn test_sync_tool_helper() {
        let tool = sync_tool::<_, ()>("echo", "Echo", |_ctx, args| {
            let msg = args["message"].as_str().unwrap_or("default");
            Ok(ToolReturn::text(msg))
        });
        assert_eq!(tool.name, "echo");
    }
}