Skip to main content

serdes_ai_toolsets/
function.rs

//! Function-based toolset implementation.
//!
//! This module provides `FunctionToolset`, which wraps a `ToolRegistry`
//! and adapts it to the `AbstractToolset` interface. It also provides
//! `AsyncFnTool`, which turns an async function plus metadata into a tool.
5
6use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{
9    ObjectJsonSchema, RunContext, Tool, ToolDefinition, ToolError, ToolRegistry, ToolReturn,
10};
11use std::collections::HashMap;
12use std::future::Future;
13use std::marker::PhantomData;
14use std::pin::Pin;
15
16use crate::{AbstractToolset, ToolsetTool};
17
18/// A toolset backed by function-based tools.
19///
20/// This wraps a `ToolRegistry` and provides the `AbstractToolset` interface.
21///
22/// # Example
23///
24/// ```ignore
25/// use serdes_ai_toolsets::FunctionToolset;
26/// use serdes_ai_tools::{SyncFunctionTool, ObjectJsonSchema};
27///
28/// let toolset = FunctionToolset::new()
29///     .with_id("my_tools")
30///     .tool(my_tool)
31///     .tool(another_tool);
32/// ```
33pub struct FunctionToolset<Deps = ()> {
34    id: Option<String>,
35    registry: ToolRegistry<Deps>,
36    max_retries: u32,
37}
38
39impl<Deps> FunctionToolset<Deps> {
40    /// Create a new empty function toolset.
41    #[must_use]
42    pub fn new() -> Self {
43        Self {
44            id: None,
45            registry: ToolRegistry::new(),
46            max_retries: 3,
47        }
48    }
49
50    /// Set the toolset ID.
51    #[must_use]
52    pub fn with_id(mut self, id: impl Into<String>) -> Self {
53        self.id = Some(id.into());
54        self
55    }
56
57    /// Set the default max retries for tools.
58    #[must_use]
59    pub fn with_max_retries(mut self, retries: u32) -> Self {
60        self.max_retries = retries;
61        self
62    }
63
64    /// Add a tool to the toolset.
65    #[must_use]
66    pub fn tool<T: Tool<Deps> + 'static>(mut self, tool: T) -> Self {
67        self.registry.register(tool);
68        self
69    }
70
71    /// Add multiple tools to the toolset.
72    #[must_use]
73    pub fn tools<I, T>(mut self, tools: I) -> Self
74    where
75        I: IntoIterator<Item = T>,
76        T: Tool<Deps> + 'static,
77    {
78        for tool in tools {
79            self.registry.register(tool);
80        }
81        self
82    }
83
84    /// Get the underlying registry.
85    #[must_use]
86    pub fn registry(&self) -> &ToolRegistry<Deps> {
87        &self.registry
88    }
89
90    /// Get mutable access to the registry.
91    pub fn registry_mut(&mut self) -> &mut ToolRegistry<Deps> {
92        &mut self.registry
93    }
94
95    /// Get the number of tools.
96    #[must_use]
97    pub fn len(&self) -> usize {
98        self.registry.len()
99    }
100
101    /// Check if the toolset is empty.
102    #[must_use]
103    pub fn is_empty(&self) -> bool {
104        self.registry.is_empty()
105    }
106}
107
108impl<Deps> Default for FunctionToolset<Deps> {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114#[async_trait]
115impl<Deps: Send + Sync + 'static> AbstractToolset<Deps> for FunctionToolset<Deps> {
116    fn id(&self) -> Option<&str> {
117        self.id.as_deref()
118    }
119
120    fn type_name(&self) -> &'static str {
121        "FunctionToolset"
122    }
123
124    async fn get_tools(
125        &self,
126        ctx: &RunContext<Deps>,
127    ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
128        let defs = self.registry.prepared_definitions(ctx).await;
129        Ok(defs
130            .into_iter()
131            .map(|def| {
132                let name = def.name.clone();
133                let max_retries = self.registry.max_retries(&name).unwrap_or(self.max_retries);
134                (
135                    name,
136                    ToolsetTool {
137                        toolset_id: self.id.clone(),
138                        tool_def: def,
139                        max_retries,
140                    },
141                )
142            })
143            .collect())
144    }
145
146    async fn call_tool(
147        &self,
148        name: &str,
149        args: JsonValue,
150        ctx: &RunContext<Deps>,
151        _tool: &ToolsetTool,
152    ) -> Result<ToolReturn, ToolError> {
153        self.registry.call(name, ctx, args).await
154    }
155}
156
157impl<Deps> std::fmt::Debug for FunctionToolset<Deps> {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        f.debug_struct("FunctionToolset")
160            .field("id", &self.id)
161            .field("tool_count", &self.registry.len())
162            .field("max_retries", &self.max_retries)
163            .finish()
164    }
165}
166
167/// Wrapper for async function tools with explicit type.
168///
169/// This provides a simpler way to add async functions as tools.
170pub struct AsyncFnTool<F, Deps> {
171    name: String,
172    description: String,
173    parameters: ObjectJsonSchema,
174    function: F,
175    max_retries: Option<u32>,
176    _phantom: PhantomData<fn() -> Deps>,
177}
178
179impl<F, Deps> AsyncFnTool<F, Deps> {
180    /// Create a new async function tool.
181    pub fn new(
182        name: impl Into<String>,
183        description: impl Into<String>,
184        parameters: ObjectJsonSchema,
185        function: F,
186    ) -> Self {
187        Self {
188            name: name.into(),
189            description: description.into(),
190            parameters,
191            function,
192            max_retries: None,
193            _phantom: PhantomData,
194        }
195    }
196
197    /// Set max retries.
198    #[must_use]
199    pub fn with_max_retries(mut self, retries: u32) -> Self {
200        self.max_retries = Some(retries);
201        self
202    }
203}
204
205type PinnedToolFuture = Pin<Box<dyn Future<Output = Result<ToolReturn, ToolError>> + Send>>;
206
207#[async_trait]
208impl<F, Deps> Tool<Deps> for AsyncFnTool<F, Deps>
209where
210    F: for<'a> Fn(&'a RunContext<Deps>, JsonValue) -> PinnedToolFuture + Send + Sync,
211    Deps: Send + Sync,
212{
213    fn definition(&self) -> ToolDefinition {
214        ToolDefinition::new(&self.name, &self.description).with_parameters(self.parameters.clone())
215    }
216
217    async fn call(&self, ctx: &RunContext<Deps>, args: JsonValue) -> Result<ToolReturn, ToolError> {
218        (self.function)(ctx, args).await
219    }
220
221    fn max_retries(&self) -> Option<u32> {
222        self.max_retries
223    }
224}
225
226impl<F, Deps> std::fmt::Debug for AsyncFnTool<F, Deps> {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        f.debug_struct("AsyncFnTool")
229            .field("name", &self.name)
230            .field("max_retries", &self.max_retries)
231            .finish()
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use serdes_ai_tools::PropertySchema;

    struct EchoTool;

    #[async_trait]
    impl Tool<()> for EchoTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("echo", "Echo the message").with_parameters(
                ObjectJsonSchema::new().with_property(
                    "msg",
                    PropertySchema::string("Message").build(),
                    true,
                ),
            )
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            let msg = args["msg"].as_str().unwrap_or("<none>");
            Ok(ToolReturn::text(msg))
        }
    }

    /// Async function used to exercise `AsyncFnTool`: upper-cases `args["msg"]`.
    fn shout(_ctx: &RunContext<()>, args: JsonValue) -> PinnedToolFuture {
        Box::pin(async move {
            let shouted = args["msg"].as_str().unwrap_or("<none>").to_uppercase();
            Ok(ToolReturn::text(shouted.as_str()))
        })
    }

    #[test]
    fn test_function_toolset_new() {
        let toolset = FunctionToolset::<()>::new();
        assert!(toolset.is_empty());
        assert!(toolset.id().is_none());
    }

    #[test]
    fn test_function_toolset_with_id() {
        let toolset = FunctionToolset::<()>::new().with_id("my_tools");
        assert_eq!(toolset.id(), Some("my_tools"));
    }

    #[test]
    fn test_function_toolset_add_tool() {
        let toolset = FunctionToolset::new().tool(EchoTool);
        assert_eq!(toolset.len(), 1);
    }

    #[tokio::test]
    async fn test_function_toolset_get_tools() {
        let toolset: FunctionToolset<()> = FunctionToolset::new().with_id("test").tool(EchoTool);

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("echo"));
        let echo = tools.get("echo").unwrap();
        assert_eq!(echo.toolset_id, Some("test".to_string()));
    }

    #[tokio::test]
    async fn test_function_toolset_call_tool() {
        let toolset = FunctionToolset::new().tool(EchoTool);
        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let echo_tool = tools.get("echo").unwrap();

        let result = toolset
            .call_tool("echo", serde_json::json!({"msg": "hello"}), &ctx, echo_tool)
            .await
            .unwrap();

        assert_eq!(result.as_text(), Some("hello"));
    }

    #[tokio::test]
    async fn test_async_fn_tool_via_toolset() {
        let shout_tool = AsyncFnTool::<_, ()>::new(
            "shout",
            "Upper-case the message",
            ObjectJsonSchema::new(),
            shout,
        );
        let toolset: FunctionToolset<()> = FunctionToolset::new().tool(shout_tool);
        assert_eq!(toolset.len(), 1);

        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let listed = tools.get("shout").unwrap();

        let result = toolset
            .call_tool("shout", serde_json::json!({"msg": "hi"}), &ctx, listed)
            .await
            .unwrap();

        assert_eq!(result.as_text(), Some("HI"));
    }

    #[test]
    fn test_async_fn_tool_debug() {
        let tool = AsyncFnTool::<_, ()>::new("shout", "Shout", ObjectJsonSchema::new(), shout)
            .with_max_retries(2);

        let debug = format!("{:?}", tool);
        assert!(debug.contains("AsyncFnTool"));
        assert!(debug.contains("shout"));
        assert!(debug.contains("Some(2)"));
    }

    #[test]
    fn test_function_toolset_debug() {
        let toolset = FunctionToolset::new()
            .with_id("debug_test")
            .with_max_retries(5)
            .tool(EchoTool);

        let debug = format!("{:?}", toolset);
        assert!(debug.contains("FunctionToolset"));
        assert!(debug.contains("debug_test"));
        assert!(debug.contains("tool_count: 1"));
        assert!(debug.contains("max_retries: 5"));
    }
}