Skip to main content

serdes_ai_toolsets/
prefixed.rs

//! Prefixed toolset implementation.
//!
//! This module provides `PrefixedToolset`, which renames every tool exposed
//! by a wrapped toolset by prepending a configurable prefix and separator.

6use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolError, ToolReturn};
9use std::collections::HashMap;
10use std::marker::PhantomData;
11
12use crate::{AbstractToolset, ToolsetTool};
13
/// Wraps another toolset and renames every tool it exposes by prepending a
/// fixed prefix followed by a separator (`"_"` unless changed with
/// [`PrefixedToolset::with_separator`]).
///
/// Combining several toolsets can produce duplicate tool names; giving each
/// wrapped toolset its own prefix keeps the exposed names distinct.
///
/// # Example
///
/// ```ignore
/// use serdes_ai_toolsets::{PrefixedToolset, FunctionToolset};
///
/// // `SearchTool` stands for any tool whose definition is named "search".
/// let web_tools = FunctionToolset::new().tool(SearchTool);
/// let local_tools = FunctionToolset::new().tool(SearchTool);
///
/// // Both inner toolsets expose "search"; the prefixes disambiguate them.
/// let web = PrefixedToolset::new(web_tools, "web");
/// let local = PrefixedToolset::new(local_tools, "local");
///
/// // Exposed names are now "web_search" and "local_search".
/// ```
pub struct PrefixedToolset<T, Deps = ()> {
    /// The wrapped toolset whose tools get renamed.
    inner: T,
    /// Text placed in front of every tool name.
    prefix: String,
    /// Text placed between the prefix and the original tool name.
    separator: String,
    /// Ties `Deps` to the type without storing a value; a `fn() -> Deps`
    /// marker keeps the struct `Send`/`Sync` independently of `Deps`.
    _phantom: PhantomData<fn() -> Deps>,
}

40impl<T, Deps> PrefixedToolset<T, Deps>
41where
42    T: AbstractToolset<Deps>,
43{
44    /// Create a new prefixed toolset with default separator "_".
45    pub fn new(inner: T, prefix: impl Into<String>) -> Self {
46        Self {
47            inner,
48            prefix: prefix.into(),
49            separator: "_".to_string(),
50            _phantom: PhantomData,
51        }
52    }
53
54    /// Set a custom separator (default is "_").
55    #[must_use]
56    pub fn with_separator(mut self, sep: impl Into<String>) -> Self {
57        self.separator = sep.into();
58        self
59    }
60
61    /// Get the prefix.
62    #[must_use]
63    pub fn prefix(&self) -> &str {
64        &self.prefix
65    }
66
67    /// Get the separator.
68    #[must_use]
69    pub fn separator(&self) -> &str {
70        &self.separator
71    }
72
73    /// Get the inner toolset.
74    #[must_use]
75    pub fn inner(&self) -> &T {
76        &self.inner
77    }
78
79    /// Create the prefixed name.
80    fn prefixed_name(&self, name: &str) -> String {
81        format!("{}{}{}", self.prefix, self.separator, name)
82    }
83
84    /// Strip the prefix from a name.
85    fn strip_prefix(&self, prefixed: &str) -> Option<String> {
86        let prefix_with_sep = format!("{}{}", self.prefix, self.separator);
87        prefixed
88            .strip_prefix(&prefix_with_sep)
89            .map(|s| s.to_string())
90    }
91}
92
93#[async_trait]
94impl<T, Deps> AbstractToolset<Deps> for PrefixedToolset<T, Deps>
95where
96    T: AbstractToolset<Deps>,
97    Deps: Send + Sync,
98{
99    fn id(&self) -> Option<&str> {
100        self.inner.id()
101    }
102
103    fn type_name(&self) -> &'static str {
104        "PrefixedToolset"
105    }
106
107    fn label(&self) -> String {
108        format!(
109            "PrefixedToolset('{}{}', {})",
110            self.prefix,
111            self.separator,
112            self.inner.label()
113        )
114    }
115
116    async fn get_tools(
117        &self,
118        ctx: &RunContext<Deps>,
119    ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
120        let inner_tools = self.inner.get_tools(ctx).await?;
121
122        Ok(inner_tools
123            .into_iter()
124            .map(|(name, mut tool)| {
125                let prefixed = self.prefixed_name(&name);
126                // Update the tool definition with the prefixed name
127                tool.tool_def.name = prefixed.clone();
128                (prefixed, tool)
129            })
130            .collect())
131    }
132
133    async fn call_tool(
134        &self,
135        name: &str,
136        args: JsonValue,
137        ctx: &RunContext<Deps>,
138        tool: &ToolsetTool,
139    ) -> Result<ToolReturn, ToolError> {
140        // Strip prefix to get the original tool name
141        let original_name = self.strip_prefix(name).ok_or_else(|| {
142            ToolError::not_found(format!(
143                "Tool '{}' does not have expected prefix '{}{}'",
144                name, self.prefix, self.separator
145            ))
146        })?;
147
148        // Create a modified tool with the original name
149        let mut original_tool = tool.clone();
150        original_tool.tool_def.name = original_name.clone();
151
152        self.inner
153            .call_tool(&original_name, args, ctx, &original_tool)
154            .await
155    }
156
157    async fn enter(&self) -> Result<(), ToolError> {
158        self.inner.enter().await
159    }
160
161    async fn exit(&self) -> Result<(), ToolError> {
162        self.inner.exit().await
163    }
164}
165
166impl<T: std::fmt::Debug, Deps> std::fmt::Debug for PrefixedToolset<T, Deps> {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        f.debug_struct("PrefixedToolset")
169            .field("prefix", &self.prefix)
170            .field("separator", &self.separator)
171            .field("inner", &self.inner)
172            .finish()
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::{Tool, ToolDefinition};

    /// Minimal tool named "search" that echoes its `query` argument.
    struct SearchTool;

    #[async_trait]
    impl Tool<()> for SearchTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("search", "Search for items")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            let query = args["query"].as_str().unwrap_or("*");
            Ok(ToolReturn::text(format!("Searching for: {}", query)))
        }
    }

    #[test]
    fn test_prefixed_name() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web");

        assert_eq!(prefixed.prefixed_name("search"), "web_search");
    }

    #[test]
    fn test_strip_prefix() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web");

        assert_eq!(
            prefixed.strip_prefix("web_search"),
            Some("search".to_string())
        );
        assert_eq!(prefixed.strip_prefix("local_search"), None);
    }

    #[test]
    fn test_strip_prefix_requires_separator() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web");

        // The prefix alone is not enough; the separator must follow it.
        assert_eq!(prefixed.strip_prefix("websearch"), None);
        assert_eq!(prefixed.strip_prefix("web"), None);

        // prefixed_name and strip_prefix must be inverses.
        let name = prefixed.prefixed_name("search");
        assert_eq!(prefixed.strip_prefix(&name), Some("search".to_string()));
    }

    #[test]
    fn test_custom_separator() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web").with_separator("::");

        assert_eq!(prefixed.prefixed_name("search"), "web::search");
    }

    #[test]
    fn test_strip_prefix_custom_separator() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web").with_separator("::");

        assert_eq!(
            prefixed.strip_prefix("web::search"),
            Some("search".to_string())
        );
        // The default separator must no longer match once a custom one is set.
        assert_eq!(prefixed.strip_prefix("web_search"), None);
    }

    #[test]
    fn test_accessors_and_type_name() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web").with_separator("-");

        assert_eq!(prefixed.prefix(), "web");
        assert_eq!(prefixed.separator(), "-");
        assert_eq!(prefixed.type_name(), "PrefixedToolset");
    }

    #[tokio::test]
    async fn test_prefixed_toolset_get_tools() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web");

        let ctx = RunContext::minimal("test");
        let tools = prefixed.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("web_search"));
        assert!(!tools.contains_key("search"));

        let tool = tools.get("web_search").unwrap();
        assert_eq!(tool.tool_def.name, "web_search");
    }

    #[tokio::test]
    async fn test_prefixed_toolset_call_tool() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web");

        let ctx = RunContext::minimal("test");
        let tools = prefixed.get_tools(&ctx).await.unwrap();
        let tool = tools.get("web_search").unwrap();

        let result = prefixed
            .call_tool(
                "web_search",
                serde_json::json!({"query": "rust"}),
                &ctx,
                tool,
            )
            .await
            .unwrap();

        assert!(result.as_text().unwrap().contains("rust"));
    }

    #[tokio::test]
    async fn test_prefixed_toolset_call_tool_custom_separator() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web").with_separator("::");

        let ctx = RunContext::minimal("test");
        let tools = prefixed.get_tools(&ctx).await.unwrap();
        let tool = tools.get("web::search").unwrap();
        assert_eq!(tool.tool_def.name, "web::search");

        let result = prefixed
            .call_tool(
                "web::search",
                serde_json::json!({"query": "crates"}),
                &ctx,
                tool,
            )
            .await
            .unwrap();

        assert!(result.as_text().unwrap().contains("crates"));
    }

    #[tokio::test]
    async fn test_prefixed_toolset_wrong_prefix() {
        let toolset = FunctionToolset::new().tool(SearchTool);
        let prefixed = PrefixedToolset::new(toolset, "web");

        let ctx = RunContext::minimal("test");
        let fake_tool = ToolsetTool::new(ToolDefinition::new("local_search", "Local search"));

        let result = prefixed
            .call_tool("local_search", serde_json::json!({}), &ctx, &fake_tool)
            .await;

        assert!(result.is_err());
    }
}