1use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{
9 ObjectJsonSchema, RunContext, Tool, ToolDefinition, ToolError, ToolRegistry, ToolReturn,
10};
11use std::collections::HashMap;
12use std::future::Future;
13use std::marker::PhantomData;
14use std::pin::Pin;
15
16use crate::{AbstractToolset, ToolsetTool};
17
18pub struct FunctionToolset<Deps = ()> {
34 id: Option<String>,
35 registry: ToolRegistry<Deps>,
36 max_retries: u32,
37}
38
39impl<Deps> FunctionToolset<Deps> {
40 #[must_use]
42 pub fn new() -> Self {
43 Self {
44 id: None,
45 registry: ToolRegistry::new(),
46 max_retries: 3,
47 }
48 }
49
50 #[must_use]
52 pub fn with_id(mut self, id: impl Into<String>) -> Self {
53 self.id = Some(id.into());
54 self
55 }
56
57 #[must_use]
59 pub fn with_max_retries(mut self, retries: u32) -> Self {
60 self.max_retries = retries;
61 self
62 }
63
64 #[must_use]
66 pub fn tool<T: Tool<Deps> + 'static>(mut self, tool: T) -> Self {
67 self.registry.register(tool);
68 self
69 }
70
71 #[must_use]
73 pub fn tools<I, T>(mut self, tools: I) -> Self
74 where
75 I: IntoIterator<Item = T>,
76 T: Tool<Deps> + 'static,
77 {
78 for tool in tools {
79 self.registry.register(tool);
80 }
81 self
82 }
83
84 #[must_use]
86 pub fn registry(&self) -> &ToolRegistry<Deps> {
87 &self.registry
88 }
89
90 pub fn registry_mut(&mut self) -> &mut ToolRegistry<Deps> {
92 &mut self.registry
93 }
94
95 #[must_use]
97 pub fn len(&self) -> usize {
98 self.registry.len()
99 }
100
101 #[must_use]
103 pub fn is_empty(&self) -> bool {
104 self.registry.is_empty()
105 }
106}
107
108impl<Deps> Default for FunctionToolset<Deps> {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114#[async_trait]
115impl<Deps: Send + Sync + 'static> AbstractToolset<Deps> for FunctionToolset<Deps> {
116 fn id(&self) -> Option<&str> {
117 self.id.as_deref()
118 }
119
120 fn type_name(&self) -> &'static str {
121 "FunctionToolset"
122 }
123
124 async fn get_tools(
125 &self,
126 ctx: &RunContext<Deps>,
127 ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
128 let defs = self.registry.prepared_definitions(ctx).await;
129 Ok(defs
130 .into_iter()
131 .map(|def| {
132 let name = def.name.clone();
133 let max_retries = self.registry.max_retries(&name).unwrap_or(self.max_retries);
134 (
135 name,
136 ToolsetTool {
137 toolset_id: self.id.clone(),
138 tool_def: def,
139 max_retries,
140 },
141 )
142 })
143 .collect())
144 }
145
146 async fn call_tool(
147 &self,
148 name: &str,
149 args: JsonValue,
150 ctx: &RunContext<Deps>,
151 _tool: &ToolsetTool,
152 ) -> Result<ToolReturn, ToolError> {
153 self.registry.call(name, ctx, args).await
154 }
155}
156
157impl<Deps> std::fmt::Debug for FunctionToolset<Deps> {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("FunctionToolset")
160 .field("id", &self.id)
161 .field("tool_count", &self.registry.len())
162 .field("max_retries", &self.max_retries)
163 .finish()
164 }
165}
166
167pub struct AsyncFnTool<F, Deps> {
171 name: String,
172 description: String,
173 parameters: ObjectJsonSchema,
174 function: F,
175 max_retries: Option<u32>,
176 _phantom: PhantomData<fn() -> Deps>,
177}
178
179impl<F, Deps> AsyncFnTool<F, Deps> {
180 pub fn new(
182 name: impl Into<String>,
183 description: impl Into<String>,
184 parameters: ObjectJsonSchema,
185 function: F,
186 ) -> Self {
187 Self {
188 name: name.into(),
189 description: description.into(),
190 parameters,
191 function,
192 max_retries: None,
193 _phantom: PhantomData,
194 }
195 }
196
197 #[must_use]
199 pub fn with_max_retries(mut self, retries: u32) -> Self {
200 self.max_retries = Some(retries);
201 self
202 }
203}
204
205type PinnedToolFuture = Pin<Box<dyn Future<Output = Result<ToolReturn, ToolError>> + Send>>;
206
207#[async_trait]
208impl<F, Deps> Tool<Deps> for AsyncFnTool<F, Deps>
209where
210 F: for<'a> Fn(&'a RunContext<Deps>, JsonValue) -> PinnedToolFuture + Send + Sync,
211 Deps: Send + Sync,
212{
213 fn definition(&self) -> ToolDefinition {
214 ToolDefinition::new(&self.name, &self.description).with_parameters(self.parameters.clone())
215 }
216
217 async fn call(&self, ctx: &RunContext<Deps>, args: JsonValue) -> Result<ToolReturn, ToolError> {
218 (self.function)(ctx, args).await
219 }
220
221 fn max_retries(&self) -> Option<u32> {
222 self.max_retries
223 }
224}
225
226impl<F, Deps> std::fmt::Debug for AsyncFnTool<F, Deps> {
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 f.debug_struct("AsyncFnTool")
229 .field("name", &self.name)
230 .field("max_retries", &self.max_retries)
231 .finish()
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use serdes_ai_tools::PropertySchema;

    /// Test tool that returns its `msg` argument as text.
    struct EchoTool;

    #[async_trait]
    impl Tool<()> for EchoTool {
        fn definition(&self) -> ToolDefinition {
            let msg_schema = PropertySchema::string("Message").build();
            let parameters = ObjectJsonSchema::new().with_property("msg", msg_schema, true);
            ToolDefinition::new("echo", "Echo the message").with_parameters(parameters)
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            // A missing or non-string `msg` yields a placeholder rather than an error.
            let text = args
                .get("msg")
                .and_then(JsonValue::as_str)
                .unwrap_or("<none>");
            Ok(ToolReturn::text(text))
        }
    }

    #[test]
    fn test_function_toolset_new() {
        let toolset: FunctionToolset<()> = FunctionToolset::new();
        assert!(toolset.is_empty());
        assert_eq!(toolset.id(), None);
    }

    #[test]
    fn test_function_toolset_with_id() {
        let toolset: FunctionToolset<()> = FunctionToolset::new().with_id("my_tools");
        assert_eq!(toolset.id(), Some("my_tools"));
    }

    #[test]
    fn test_function_toolset_add_tool() {
        let toolset = FunctionToolset::new().tool(EchoTool);
        assert_eq!(toolset.len(), 1);
    }

    #[tokio::test]
    async fn test_function_toolset_get_tools() {
        let toolset = FunctionToolset::<()>::new().with_id("test").tool(EchoTool);
        let ctx = RunContext::minimal("test");

        let tools = toolset.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        let echo = tools.get("echo").expect("echo tool should be listed");
        assert_eq!(echo.toolset_id.as_deref(), Some("test"));
    }

    #[tokio::test]
    async fn test_function_toolset_call_tool() {
        let toolset = FunctionToolset::new().tool(EchoTool);
        let ctx = RunContext::minimal("test");
        let tools = toolset.get_tools(&ctx).await.unwrap();
        let args = serde_json::json!({ "msg": "hello" });

        let result = toolset
            .call_tool("echo", args, &ctx, &tools["echo"])
            .await
            .unwrap();

        assert_eq!(result.as_text(), Some("hello"));
    }

    #[test]
    fn test_function_toolset_debug() {
        let toolset = FunctionToolset::new()
            .with_id("debug_test")
            .with_max_retries(5)
            .tool(EchoTool);

        let rendered = format!("{toolset:?}");

        assert!(rendered.contains("FunctionToolset"));
        assert!(rendered.contains("debug_test"));
    }
}