serdes_ai_toolsets/
renamed.rs1use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolError, ToolReturn};
9use std::collections::HashMap;
10use std::marker::PhantomData;
11
12use crate::{AbstractToolset, ToolsetTool};
13
/// Toolset wrapper that exposes the tools of an inner toolset under different names.
///
/// Construct it with `RenamedToolset::new` followed by `rename` calls, or with
/// `RenamedToolset::with_map` when the full mapping is already known.
pub struct RenamedToolset<T, Deps = ()> {
    /// The wrapped toolset; tool listing and tool calls are forwarded to it.
    inner: T,
    /// Maps an exposed (new) tool name to the original tool name inside `inner`.
    name_map: HashMap<String, String>,
    /// Binds the `Deps` type parameter without storing a `Deps` value.
    /// Using `fn() -> Deps` keeps `Send`/`Sync` of the wrapper independent of `Deps`.
    _phantom: PhantomData<fn() -> Deps>,
}
35
36impl<T, Deps> RenamedToolset<T, Deps>
37where
38 T: AbstractToolset<Deps>,
39{
40 pub fn new(inner: T) -> Self {
42 Self {
43 inner,
44 name_map: HashMap::new(),
45 _phantom: PhantomData,
46 }
47 }
48
49 pub fn with_map(inner: T, name_map: HashMap<String, String>) -> Self {
51 Self {
52 inner,
53 name_map,
54 _phantom: PhantomData,
55 }
56 }
57
58 #[must_use]
64 pub fn rename(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
65 self.name_map.insert(to.into(), from.into());
67 self
68 }
69
70 #[must_use]
72 pub fn inner(&self) -> &T {
73 &self.inner
74 }
75
76 #[must_use]
78 pub fn name_map(&self) -> &HashMap<String, String> {
79 &self.name_map
80 }
81
82 fn original_name<'a>(&'a self, new_name: &'a str) -> &'a str {
84 self.name_map
85 .get(new_name)
86 .map(|s| s.as_str())
87 .unwrap_or(new_name)
88 }
89
90 fn new_name(&self, original: &str) -> String {
92 for (new, old) in &self.name_map {
94 if old == original {
95 return new.clone();
96 }
97 }
98 original.to_string()
99 }
100}
101
102#[async_trait]
103impl<T, Deps> AbstractToolset<Deps> for RenamedToolset<T, Deps>
104where
105 T: AbstractToolset<Deps>,
106 Deps: Send + Sync,
107{
108 fn id(&self) -> Option<&str> {
109 self.inner.id()
110 }
111
112 fn type_name(&self) -> &'static str {
113 "RenamedToolset"
114 }
115
116 fn label(&self) -> String {
117 format!("RenamedToolset({})", self.inner.label())
118 }
119
120 async fn get_tools(
121 &self,
122 ctx: &RunContext<Deps>,
123 ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
124 let inner_tools = self.inner.get_tools(ctx).await?;
125
126 Ok(inner_tools
127 .into_iter()
128 .map(|(original_name, mut tool)| {
129 let new_name = self.new_name(&original_name);
130 tool.tool_def.name = new_name.clone();
131 (new_name, tool)
132 })
133 .collect())
134 }
135
136 async fn call_tool(
137 &self,
138 name: &str,
139 args: JsonValue,
140 ctx: &RunContext<Deps>,
141 tool: &ToolsetTool,
142 ) -> Result<ToolReturn, ToolError> {
143 let original_name = self.original_name(name);
144
145 let mut original_tool = tool.clone();
147 original_tool.tool_def.name = original_name.to_string();
148
149 self.inner
150 .call_tool(original_name, args, ctx, &original_tool)
151 .await
152 }
153
154 async fn enter(&self) -> Result<(), ToolError> {
155 self.inner.enter().await
156 }
157
158 async fn exit(&self) -> Result<(), ToolError> {
159 self.inner.exit().await
160 }
161}
162
163impl<T: std::fmt::Debug, Deps> std::fmt::Debug for RenamedToolset<T, Deps> {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 f.debug_struct("RenamedToolset")
166 .field("inner", &self.inner)
167 .field("name_map", &self.name_map)
168 .finish()
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::{Tool, ToolDefinition};

    /// Fixture tool registered in the inner toolset under the name `search`.
    struct SearchTool;

    #[async_trait]
    impl Tool<()> for SearchTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("search", "Search for items")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("search result"))
        }
    }

    /// Fixture tool registered in the inner toolset under the name `query`.
    struct QueryTool;

    #[async_trait]
    impl Tool<()> for QueryTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("query", "Query the database")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("query result"))
        }
    }

    #[test]
    fn test_original_name() {
        let renamed = RenamedToolset::new(FunctionToolset::new().tool(SearchTool))
            .rename("search", "find");

        // A mapped name resolves back; an unmapped one passes through.
        for (exposed, expected) in [("find", "search"), ("other", "other")] {
            assert_eq!(renamed.original_name(exposed), expected);
        }
    }

    #[test]
    fn test_new_name() {
        let renamed = RenamedToolset::new(FunctionToolset::new().tool(SearchTool))
            .rename("search", "find");

        for (original, expected) in [("search", "find"), ("other", "other")] {
            assert_eq!(renamed.new_name(original), expected);
        }
    }

    #[tokio::test]
    async fn test_renamed_toolset_get_tools() {
        let inner = FunctionToolset::new().tool(SearchTool).tool(QueryTool);
        let renamed = RenamedToolset::new(inner).rename("search", "find_items");
        let ctx = RunContext::minimal("test");

        let tools = renamed
            .get_tools(&ctx)
            .await
            .expect("get_tools should succeed");

        assert_eq!(tools.len(), 2);
        for present in ["find_items", "query"] {
            assert!(tools.contains_key(present), "expected tool `{}`", present);
        }
        assert!(!tools.contains_key("search"), "old name must not be listed");
    }

    #[tokio::test]
    async fn test_renamed_toolset_call_tool() {
        let renamed = RenamedToolset::new(FunctionToolset::new().tool(SearchTool))
            .rename("search", "find_items");
        let ctx = RunContext::minimal("test");

        let tools = renamed
            .get_tools(&ctx)
            .await
            .expect("get_tools should succeed");
        let tool = tools
            .get("find_items")
            .expect("renamed tool should be listed");

        let result = renamed
            .call_tool("find_items", serde_json::json!({}), &ctx, tool)
            .await
            .expect("call_tool should succeed");

        assert_eq!(result.as_text(), Some("search result"));
    }

    #[tokio::test]
    async fn test_renamed_toolset_multiple_renames() {
        let inner = FunctionToolset::new().tool(SearchTool).tool(QueryTool);
        let renamed = RenamedToolset::new(inner)
            .rename("search", "find")
            .rename("query", "lookup");
        let ctx = RunContext::minimal("test");

        let tools = renamed
            .get_tools(&ctx)
            .await
            .expect("get_tools should succeed");

        for present in ["find", "lookup"] {
            assert!(tools.contains_key(present), "expected tool `{}`", present);
        }
    }
}