Skip to main content

serdes_ai_toolsets/
combined.rs

1//! Combined toolset for merging multiple toolsets.
2//!
3//! This module provides `CombinedToolset`, which merges multiple toolsets
4//! into a single unified toolset.
5
6use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolError, ToolReturn};
9use std::collections::HashMap;
10
11use crate::{AbstractToolset, BoxedToolset, ToolsetTool};
12
13/// Combines multiple toolsets into one.
14///
15/// This allows treating multiple toolsets as a single collection.
16/// It handles name conflicts by returning an error with helpful hints.
17///
18/// # Example
19///
20/// ```ignore
21/// use serdes_ai_toolsets::{CombinedToolset, FunctionToolset};
22///
23/// let toolset1 = FunctionToolset::with_id("tools1").tool(tool_a);
24/// let toolset2 = FunctionToolset::with_id("tools2").tool(tool_b);
25///
26/// let combined = CombinedToolset::new()
27///     .with_toolset(toolset1)
28///     .with_toolset(toolset2);
29/// ```
30pub struct CombinedToolset<Deps = ()> {
31    id: Option<String>,
32    toolsets: Vec<BoxedToolset<Deps>>,
33}
34
35impl<Deps> CombinedToolset<Deps> {
36    /// Create a new empty combined toolset.
37    #[must_use]
38    pub fn new() -> Self {
39        Self {
40            id: None,
41            toolsets: Vec::new(),
42        }
43    }
44
45    /// Create a combined toolset with an ID.
46    #[must_use]
47    pub fn with_id(id: impl Into<String>) -> Self {
48        Self {
49            id: Some(id.into()),
50            toolsets: Vec::new(),
51        }
52    }
53
54    /// Add a toolset.
55    #[must_use]
56    pub fn with_toolset<T: AbstractToolset<Deps> + 'static>(mut self, toolset: T) -> Self {
57        self.toolsets.push(Box::new(toolset));
58        self
59    }
60
61    /// Add a boxed toolset.
62    #[must_use]
63    pub fn add_boxed(mut self, toolset: BoxedToolset<Deps>) -> Self {
64        self.toolsets.push(toolset);
65        self
66    }
67
68    /// Add multiple toolsets.
69    #[must_use]
70    pub fn toolsets<I, T>(mut self, toolsets: I) -> Self
71    where
72        I: IntoIterator<Item = T>,
73        T: AbstractToolset<Deps> + 'static,
74    {
75        for toolset in toolsets {
76            self.toolsets.push(Box::new(toolset));
77        }
78        self
79    }
80
81    /// Get the number of contained toolsets.
82    #[must_use]
83    pub fn toolset_count(&self) -> usize {
84        self.toolsets.len()
85    }
86
87    /// Check if empty.
88    #[must_use]
89    pub fn is_empty(&self) -> bool {
90        self.toolsets.is_empty()
91    }
92}
93
94impl<Deps> Default for CombinedToolset<Deps> {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100/// Track which toolset owns which tool.
101#[derive(Clone)]
102struct ToolOwnership {
103    toolset_index: usize,
104    tool: ToolsetTool,
105}
106
107#[async_trait]
108impl<Deps: Send + Sync + 'static> AbstractToolset<Deps> for CombinedToolset<Deps> {
109    fn id(&self) -> Option<&str> {
110        self.id.as_deref()
111    }
112
113    fn type_name(&self) -> &'static str {
114        "CombinedToolset"
115    }
116
117    fn tool_name_conflict_hint(&self) -> String {
118        "Use PrefixedToolset to add prefixes to tool names from different toolsets.".to_string()
119    }
120
121    async fn get_tools(
122        &self,
123        ctx: &RunContext<Deps>,
124    ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
125        let mut all_tools: HashMap<String, ToolOwnership> = HashMap::new();
126        let mut conflicts: Vec<(String, String, String)> = Vec::new();
127
128        for (idx, toolset) in self.toolsets.iter().enumerate() {
129            let tools = toolset.get_tools(ctx).await?;
130
131            for (name, tool) in tools {
132                if let Some(existing) = all_tools.get(&name) {
133                    // Track conflict
134                    let existing_label = self.toolsets[existing.toolset_index].label();
135                    let new_label = toolset.label();
136                    conflicts.push((name.clone(), existing_label, new_label));
137                } else {
138                    all_tools.insert(
139                        name,
140                        ToolOwnership {
141                            toolset_index: idx,
142                            tool,
143                        },
144                    );
145                }
146            }
147        }
148
149        if !conflicts.is_empty() {
150            let conflict_msgs: Vec<String> = conflicts
151                .iter()
152                .map(|(name, t1, t2)| format!("  - '{}' exists in {} and {}", name, t1, t2))
153                .collect();
154
155            return Err(ToolError::execution_failed(format!(
156                "Tool name conflicts in {}:\n{}\n\nHint: {}",
157                self.label(),
158                conflict_msgs.join("\n"),
159                self.tool_name_conflict_hint()
160            )));
161        }
162
163        Ok(all_tools
164            .into_iter()
165            .map(|(name, ownership)| (name, ownership.tool))
166            .collect())
167    }
168
169    async fn call_tool(
170        &self,
171        name: &str,
172        args: JsonValue,
173        ctx: &RunContext<Deps>,
174        tool: &ToolsetTool,
175    ) -> Result<ToolReturn, ToolError> {
176        // Find which toolset has this tool
177        for toolset in &self.toolsets {
178            let tools = toolset.get_tools(ctx).await?;
179            if tools.contains_key(name) {
180                return toolset.call_tool(name, args, ctx, tool).await;
181            }
182        }
183
184        Err(ToolError::not_found(format!(
185            "Tool '{}' not found in {}",
186            name,
187            self.label()
188        )))
189    }
190
191    async fn enter(&self) -> Result<(), ToolError> {
192        for toolset in &self.toolsets {
193            toolset.enter().await?;
194        }
195        Ok(())
196    }
197
198    async fn exit(&self) -> Result<(), ToolError> {
199        // Exit in reverse order
200        for toolset in self.toolsets.iter().rev() {
201            toolset.exit().await?;
202        }
203        Ok(())
204    }
205}
206
207impl<Deps> std::fmt::Debug for CombinedToolset<Deps> {
208    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209        f.debug_struct("CombinedToolset")
210            .field("id", &self.id)
211            .field("toolset_count", &self.toolsets.len())
212            .finish()
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::{Tool, ToolDefinition};

    /// Test tool named `tool_a` whose call always returns the text "A".
    struct ToolA;

    #[async_trait]
    impl Tool<()> for ToolA {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_a", "Tool A")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("A"))
        }
    }

    /// Test tool named `tool_b` whose call always returns the text "B".
    struct ToolB;

    #[async_trait]
    impl Tool<()> for ToolB {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_b", "Tool B")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("B"))
        }
    }

    /// Test tool that deliberately reuses the name `tool_a` to provoke a conflict.
    struct ConflictingTool;

    #[async_trait]
    impl Tool<()> for ConflictingTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_a", "Conflicting Tool A")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("Conflict"))
        }
    }

    #[test]
    fn test_combined_toolset_new() {
        let empty = CombinedToolset::<()>::new();
        assert_eq!(empty.toolset_count(), 0);
        assert!(empty.is_empty());
    }

    #[test]
    fn test_combined_toolset_with_id() {
        let named = CombinedToolset::<()>::with_id("combined");
        assert_eq!(named.id(), Some("combined"));
    }

    #[tokio::test]
    async fn test_combined_toolset_merges_tools() {
        let combined = CombinedToolset::new()
            .with_toolset(FunctionToolset::new().with_id("ts1").tool(ToolA))
            .with_toolset(FunctionToolset::new().with_id("ts2").tool(ToolB));
        let ctx = RunContext::minimal("test");

        let merged = combined.get_tools(&ctx).await.unwrap();

        assert_eq!(merged.len(), 2);
        for expected in ["tool_a", "tool_b"] {
            assert!(merged.contains_key(expected));
        }
    }

    #[tokio::test]
    async fn test_combined_toolset_call_tool() {
        let combined = CombinedToolset::new()
            .with_toolset(FunctionToolset::new().tool(ToolA))
            .with_toolset(FunctionToolset::new().tool(ToolB));
        let ctx = RunContext::minimal("test");

        let listed = combined.get_tools(&ctx).await.unwrap();
        let target = listed.get("tool_a").unwrap();
        let output = combined
            .call_tool("tool_a", serde_json::json!({}), &ctx, target)
            .await
            .unwrap();

        assert_eq!(output.as_text(), Some("A"));
    }

    #[tokio::test]
    async fn test_combined_toolset_conflict_detection() {
        let combined = CombinedToolset::new()
            .with_toolset(FunctionToolset::new().with_id("ts1").tool(ToolA))
            .with_toolset(FunctionToolset::new().with_id("ts2").tool(ConflictingTool));
        let ctx = RunContext::minimal("test");

        let err = match combined.get_tools(&ctx).await {
            Ok(_) => panic!("expected a tool name conflict error"),
            Err(err) => err,
        };

        let message = err.message();
        assert!(message.contains("conflict"));
        assert!(message.contains("tool_a"));
    }

    #[tokio::test]
    async fn test_combined_toolset_enter_exit() {
        use std::sync::atomic::{AtomicU32, Ordering};
        use std::sync::Arc;

        /// Toolset with no tools that counts how often it is entered and exited.
        struct TrackedToolset {
            enter_count: Arc<AtomicU32>,
            exit_count: Arc<AtomicU32>,
        }

        #[async_trait]
        impl AbstractToolset<()> for TrackedToolset {
            fn id(&self) -> Option<&str> {
                None
            }

            async fn get_tools(
                &self,
                _ctx: &RunContext<()>,
            ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
                Ok(HashMap::new())
            }

            async fn call_tool(
                &self,
                _name: &str,
                _args: JsonValue,
                _ctx: &RunContext<()>,
                _tool: &ToolsetTool,
            ) -> Result<ToolReturn, ToolError> {
                Ok(ToolReturn::empty())
            }

            async fn enter(&self) -> Result<(), ToolError> {
                self.enter_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }

            async fn exit(&self) -> Result<(), ToolError> {
                self.exit_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        }

        let enters = Arc::new(AtomicU32::new(0));
        let exits = Arc::new(AtomicU32::new(0));
        let tracked = || TrackedToolset {
            enter_count: Arc::clone(&enters),
            exit_count: Arc::clone(&exits),
        };

        let combined = CombinedToolset::new()
            .with_toolset(tracked())
            .with_toolset(tracked());

        combined.enter().await.unwrap();
        assert_eq!(enters.load(Ordering::SeqCst), 2);

        combined.exit().await.unwrap();
        assert_eq!(exits.load(Ordering::SeqCst), 2);
    }
}