1use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolError, ToolReturn};
9use std::collections::HashMap;
10
11use crate::{AbstractToolset, BoxedToolset, ToolsetTool};
12
13pub struct CombinedToolset<Deps = ()> {
31 id: Option<String>,
32 toolsets: Vec<BoxedToolset<Deps>>,
33}
34
35impl<Deps> CombinedToolset<Deps> {
36 #[must_use]
38 pub fn new() -> Self {
39 Self {
40 id: None,
41 toolsets: Vec::new(),
42 }
43 }
44
45 #[must_use]
47 pub fn with_id(id: impl Into<String>) -> Self {
48 Self {
49 id: Some(id.into()),
50 toolsets: Vec::new(),
51 }
52 }
53
54 #[must_use]
56 pub fn with_toolset<T: AbstractToolset<Deps> + 'static>(mut self, toolset: T) -> Self {
57 self.toolsets.push(Box::new(toolset));
58 self
59 }
60
61 #[must_use]
63 pub fn add_boxed(mut self, toolset: BoxedToolset<Deps>) -> Self {
64 self.toolsets.push(toolset);
65 self
66 }
67
68 #[must_use]
70 pub fn toolsets<I, T>(mut self, toolsets: I) -> Self
71 where
72 I: IntoIterator<Item = T>,
73 T: AbstractToolset<Deps> + 'static,
74 {
75 for toolset in toolsets {
76 self.toolsets.push(Box::new(toolset));
77 }
78 self
79 }
80
81 #[must_use]
83 pub fn toolset_count(&self) -> usize {
84 self.toolsets.len()
85 }
86
87 #[must_use]
89 pub fn is_empty(&self) -> bool {
90 self.toolsets.is_empty()
91 }
92}
93
94impl<Deps> Default for CombinedToolset<Deps> {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100#[derive(Clone)]
102struct ToolOwnership {
103 toolset_index: usize,
104 tool: ToolsetTool,
105}
106
107#[async_trait]
108impl<Deps: Send + Sync + 'static> AbstractToolset<Deps> for CombinedToolset<Deps> {
109 fn id(&self) -> Option<&str> {
110 self.id.as_deref()
111 }
112
113 fn type_name(&self) -> &'static str {
114 "CombinedToolset"
115 }
116
117 fn tool_name_conflict_hint(&self) -> String {
118 "Use PrefixedToolset to add prefixes to tool names from different toolsets.".to_string()
119 }
120
121 async fn get_tools(
122 &self,
123 ctx: &RunContext<Deps>,
124 ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
125 let mut all_tools: HashMap<String, ToolOwnership> = HashMap::new();
126 let mut conflicts: Vec<(String, String, String)> = Vec::new();
127
128 for (idx, toolset) in self.toolsets.iter().enumerate() {
129 let tools = toolset.get_tools(ctx).await?;
130
131 for (name, tool) in tools {
132 if let Some(existing) = all_tools.get(&name) {
133 let existing_label = self.toolsets[existing.toolset_index].label();
135 let new_label = toolset.label();
136 conflicts.push((name.clone(), existing_label, new_label));
137 } else {
138 all_tools.insert(
139 name,
140 ToolOwnership {
141 toolset_index: idx,
142 tool,
143 },
144 );
145 }
146 }
147 }
148
149 if !conflicts.is_empty() {
150 let conflict_msgs: Vec<String> = conflicts
151 .iter()
152 .map(|(name, t1, t2)| format!(" - '{}' exists in {} and {}", name, t1, t2))
153 .collect();
154
155 return Err(ToolError::execution_failed(format!(
156 "Tool name conflicts in {}:\n{}\n\nHint: {}",
157 self.label(),
158 conflict_msgs.join("\n"),
159 self.tool_name_conflict_hint()
160 )));
161 }
162
163 Ok(all_tools
164 .into_iter()
165 .map(|(name, ownership)| (name, ownership.tool))
166 .collect())
167 }
168
169 async fn call_tool(
170 &self,
171 name: &str,
172 args: JsonValue,
173 ctx: &RunContext<Deps>,
174 tool: &ToolsetTool,
175 ) -> Result<ToolReturn, ToolError> {
176 for toolset in &self.toolsets {
178 let tools = toolset.get_tools(ctx).await?;
179 if tools.contains_key(name) {
180 return toolset.call_tool(name, args, ctx, tool).await;
181 }
182 }
183
184 Err(ToolError::not_found(format!(
185 "Tool '{}' not found in {}",
186 name,
187 self.label()
188 )))
189 }
190
191 async fn enter(&self) -> Result<(), ToolError> {
192 for toolset in &self.toolsets {
193 toolset.enter().await?;
194 }
195 Ok(())
196 }
197
198 async fn exit(&self) -> Result<(), ToolError> {
199 for toolset in self.toolsets.iter().rev() {
201 toolset.exit().await?;
202 }
203 Ok(())
204 }
205}
206
207impl<Deps> std::fmt::Debug for CombinedToolset<Deps> {
208 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 f.debug_struct("CombinedToolset")
210 .field("id", &self.id)
211 .field("toolset_count", &self.toolsets.len())
212 .finish()
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::{Tool, ToolDefinition};

    /// Tool named `tool_a` that always answers "A".
    struct ToolA;

    #[async_trait]
    impl Tool<()> for ToolA {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_a", "Tool A")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("A"))
        }
    }

    /// Tool named `tool_b` that always answers "B".
    struct ToolB;

    #[async_trait]
    impl Tool<()> for ToolB {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_b", "Tool B")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("B"))
        }
    }

    /// Second tool that reuses the name `tool_a`, used to provoke a conflict.
    struct ConflictingTool;

    #[async_trait]
    impl Tool<()> for ConflictingTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_a", "Conflicting Tool A")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("Conflict"))
        }
    }

    #[test]
    fn test_combined_toolset_new() {
        let empty = CombinedToolset::<()>::new();
        assert_eq!(empty.toolset_count(), 0);
        assert!(empty.is_empty());
    }

    #[test]
    fn test_combined_toolset_with_id() {
        let named = CombinedToolset::<()>::with_id("combined");
        assert_eq!(named.id(), Some("combined"));
    }

    #[tokio::test]
    async fn test_combined_toolset_merges_tools() {
        let combined = CombinedToolset::new()
            .with_toolset(FunctionToolset::new().with_id("ts1").tool(ToolA))
            .with_toolset(FunctionToolset::new().with_id("ts2").tool(ToolB));

        let ctx = RunContext::minimal("test");
        let tools = combined.get_tools(&ctx).await.unwrap();

        // Exactly the two distinct tools, nothing more.
        let mut names: Vec<&str> = tools.keys().map(String::as_str).collect();
        names.sort_unstable();
        assert_eq!(names, ["tool_a", "tool_b"]);
    }

    #[tokio::test]
    async fn test_combined_toolset_call_tool() {
        let combined = CombinedToolset::new()
            .with_toolset(FunctionToolset::new().tool(ToolA))
            .with_toolset(FunctionToolset::new().tool(ToolB));

        let ctx = RunContext::minimal("test");
        let tools = combined.get_tools(&ctx).await.unwrap();
        let tool_a = tools.get("tool_a").expect("tool_a should be listed");

        let output = combined
            .call_tool("tool_a", serde_json::json!({}), &ctx, tool_a)
            .await
            .expect("calling tool_a should succeed");

        assert_eq!(output.as_text(), Some("A"));
    }

    #[tokio::test]
    async fn test_combined_toolset_conflict_detection() {
        let first = FunctionToolset::new().with_id("ts1").tool(ToolA);
        let second = FunctionToolset::new().with_id("ts2").tool(ConflictingTool);
        let combined = CombinedToolset::new().with_toolset(first).with_toolset(second);

        let ctx = RunContext::minimal("test");
        let err = match combined.get_tools(&ctx).await {
            Ok(_) => panic!("duplicate tool names must be rejected"),
            Err(err) => err,
        };

        let message = err.message();
        assert!(message.contains("conflict"));
        assert!(message.contains("tool_a"));
    }

    #[tokio::test]
    async fn test_combined_toolset_enter_exit() {
        use std::sync::atomic::{AtomicU32, Ordering};
        use std::sync::Arc;

        /// Toolset with no tools that only counts `enter`/`exit` calls.
        struct TrackedToolset {
            enter_count: Arc<AtomicU32>,
            exit_count: Arc<AtomicU32>,
        }

        #[async_trait]
        impl AbstractToolset<()> for TrackedToolset {
            fn id(&self) -> Option<&str> {
                None
            }

            async fn get_tools(
                &self,
                _ctx: &RunContext<()>,
            ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
                Ok(HashMap::new())
            }

            async fn call_tool(
                &self,
                _name: &str,
                _args: JsonValue,
                _ctx: &RunContext<()>,
                _tool: &ToolsetTool,
            ) -> Result<ToolReturn, ToolError> {
                Ok(ToolReturn::empty())
            }

            async fn enter(&self) -> Result<(), ToolError> {
                self.enter_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }

            async fn exit(&self) -> Result<(), ToolError> {
                self.exit_count.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        }

        let enter_count = Arc::new(AtomicU32::new(0));
        let exit_count = Arc::new(AtomicU32::new(0));
        // Every tracked child shares the same pair of counters.
        let tracked = || TrackedToolset {
            enter_count: Arc::clone(&enter_count),
            exit_count: Arc::clone(&exit_count),
        };

        let combined = CombinedToolset::new()
            .with_toolset(tracked())
            .with_toolset(tracked());

        combined.enter().await.unwrap();
        assert_eq!(enter_count.load(Ordering::SeqCst), 2);

        combined.exit().await.unwrap();
        assert_eq!(exit_count.load(Ordering::SeqCst), 2);
    }
}