1use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::BoxFuture;
7use crate::error::SynwireError;
8use crate::tools::traits::{Tool, ToolProvider, validate_tool_name};
9use crate::tools::types::{ToolOutput, ToolSchema};
10
/// Type-erased async function backing a [`StructuredTool`].
///
/// It receives the tool's JSON input and returns a boxed future that resolves to the
/// tool output or an error. The closure is `Send + Sync` and held in an `Arc`, so it
/// can be shared by reference counting and across threads; the `'static` future
/// lifetime means the returned future cannot borrow from the caller.
type ToolFn = Arc<
    dyn Fn(serde_json::Value) -> BoxFuture<'static, Result<ToolOutput, SynwireError>> + Send + Sync,
>;
15
/// A [`Tool`] assembled from explicit metadata and an async closure.
///
/// Construct one with [`StructuredTool::builder`]. Invoking the tool forwards the JSON
/// input straight to the stored function; the `schema` is exposed to callers but is
/// not checked against the input by this type.
pub struct StructuredTool {
    /// Tool name; checked with `validate_tool_name` when the builder runs `build`.
    name: String,
    /// Human-readable description returned by `Tool::description`.
    description: String,
    /// Schema returned by `Tool::schema`.
    schema: ToolSchema,
    /// The async function run by `Tool::invoke`.
    func: ToolFn,
}
51
52impl std::fmt::Debug for StructuredTool {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("StructuredTool")
55 .field("name", &self.name)
56 .field("description", &self.description)
57 .field("schema", &self.schema)
58 .field("func", &"<closure>")
59 .finish()
60 }
61}
62
63impl StructuredTool {
64 pub fn builder() -> StructuredToolBuilder {
66 StructuredToolBuilder {
67 name: None,
68 description: None,
69 schema: None,
70 func: None,
71 }
72 }
73}
74
75impl Tool for StructuredTool {
76 fn name(&self) -> &str {
77 &self.name
78 }
79
80 fn description(&self) -> &str {
81 &self.description
82 }
83
84 fn schema(&self) -> &ToolSchema {
85 &self.schema
86 }
87
88 fn invoke(&self, input: serde_json::Value) -> BoxFuture<'_, Result<ToolOutput, SynwireError>> {
89 (self.func)(input)
90 }
91}
92
/// Builder for [`StructuredTool`].
///
/// Every field starts unset (`None`); `build` fails with a validation error unless all
/// four have been provided.
#[derive(Default)]
pub struct StructuredToolBuilder {
    /// Value set by the `name` method.
    name: Option<String>,
    /// Value set by the `description` method.
    description: Option<String>,
    /// Value set by the `schema` method.
    schema: Option<ToolSchema>,
    /// Closure set by the `func` method.
    func: Option<ToolFn>,
}
104
105impl StructuredToolBuilder {
106 #[must_use]
108 pub fn name(mut self, name: impl Into<String>) -> Self {
109 self.name = Some(name.into());
110 self
111 }
112
113 #[must_use]
115 pub fn description(mut self, description: impl Into<String>) -> Self {
116 self.description = Some(description.into());
117 self
118 }
119
120 #[must_use]
122 pub fn schema(mut self, schema: ToolSchema) -> Self {
123 self.schema = Some(schema);
124 self
125 }
126
127 #[must_use]
129 pub fn func<F>(mut self, f: F) -> Self
130 where
131 F: Fn(serde_json::Value) -> BoxFuture<'static, Result<ToolOutput, SynwireError>>
132 + Send
133 + Sync
134 + 'static,
135 {
136 self.func = Some(Arc::new(f));
137 self
138 }
139
140 pub fn build(self) -> Result<StructuredTool, SynwireError> {
148 let name = self.name.ok_or_else(|| {
149 SynwireError::Tool(crate::error::ToolError::ValidationFailed {
150 message: "tool name is required".into(),
151 })
152 })?;
153 let description = self.description.ok_or_else(|| {
154 SynwireError::Tool(crate::error::ToolError::ValidationFailed {
155 message: "tool description is required".into(),
156 })
157 })?;
158 let schema = self.schema.ok_or_else(|| {
159 SynwireError::Tool(crate::error::ToolError::ValidationFailed {
160 message: "tool schema is required".into(),
161 })
162 })?;
163 let func = self.func.ok_or_else(|| {
164 SynwireError::Tool(crate::error::ToolError::ValidationFailed {
165 message: "tool function is required".into(),
166 })
167 })?;
168
169 validate_tool_name(&name)?;
170
171 Ok(StructuredTool {
172 name,
173 description,
174 schema,
175 func,
176 })
177 }
178}
179
#[cfg(test)]
// Tests unwrap freely: a panic is the intended failure signal here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Minimal object schema carrying `name`.
    fn make_schema(name: &str) -> ToolSchema {
        ToolSchema {
            name: name.into(),
            description: "test".into(),
            parameters: serde_json::json!({"type": "object"}),
        }
    }

    /// Tool body that succeeds with the JSON input rendered as text.
    fn make_echo_func()
    -> impl Fn(serde_json::Value) -> BoxFuture<'static, Result<ToolOutput, SynwireError>> + Send + Sync
    {
        |input| {
            Box::pin(async move {
                Ok(ToolOutput {
                    content: input.to_string(),
                    artifact: None,
                    binary_results: Vec::new(),
                    status: crate::tools::ToolResultStatus::Success,
                    telemetry: None,
                    content_type: None,
                })
            })
        }
    }

    #[tokio::test]
    async fn structured_tool_invoke_valid_input() {
        let tool = StructuredTool::builder()
            .name("echo")
            .description("echoes input")
            .schema(make_schema("echo"))
            .func(make_echo_func())
            .build()
            .unwrap();

        let result = tool
            .invoke(serde_json::json!({"msg": "hello"}))
            .await
            .unwrap();
        assert!(result.content.contains("hello"));
    }

    #[test]
    fn schema_is_serialisable() {
        let tool = StructuredTool::builder()
            .name("my-tool")
            .description("a tool")
            .schema(make_schema("my-tool"))
            .func(make_echo_func())
            .build()
            .unwrap();

        let json = serde_json::to_value(tool.schema()).unwrap();
        assert_eq!(json["name"], "my-tool");
    }

    #[tokio::test]
    async fn invoke_with_error_func() {
        let tool = StructuredTool::builder()
            .name("fail-tool")
            .description("always fails")
            .schema(make_schema("fail-tool"))
            .func(|_input| {
                Box::pin(async {
                    Err(SynwireError::Tool(
                        crate::error::ToolError::InvocationFailed {
                            message: "boom".into(),
                        },
                    ))
                })
            })
            .build()
            .unwrap();

        let result = tool.invoke(serde_json::json!({})).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("boom"));
    }

    #[test]
    fn builder_rejects_invalid_name() {
        let result = StructuredTool::builder()
            .name("bad name!")
            .description("d")
            .schema(make_schema("bad name!"))
            .func(make_echo_func())
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn builder_requires_all_fields() {
        let result = StructuredTool::builder()
            // `func` deliberately omitted.
            .name("ok")
            .description("d")
            .schema(make_schema("ok"))
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("function"));
    }

    #[test]
    fn builder_reports_each_missing_field() {
        let missing_name = StructuredTool::builder()
            .description("d")
            .schema(make_schema("ok"))
            .func(make_echo_func())
            .build();
        assert!(missing_name.unwrap_err().to_string().contains("name"));

        let missing_description = StructuredTool::builder()
            .name("ok")
            .schema(make_schema("ok"))
            .func(make_echo_func())
            .build();
        assert!(
            missing_description
                .unwrap_err()
                .to_string()
                .contains("description")
        );

        let missing_schema = StructuredTool::builder()
            .name("ok")
            .description("d")
            .func(make_echo_func())
            .build();
        assert!(missing_schema.unwrap_err().to_string().contains("schema"));
    }

    #[test]
    fn debug_output_masks_closure() {
        let tool = StructuredTool::builder()
            .name("dbg")
            .description("debug me")
            .schema(make_schema("dbg"))
            .func(make_echo_func())
            .build()
            .unwrap();

        let rendered = format!("{tool:?}");
        assert!(rendered.contains("StructuredTool"));
        assert!(rendered.contains("dbg"));
        assert!(rendered.contains("<closure>"));
    }
}
287
/// A [`ToolProvider`] that serves a fixed list of tools supplied at construction time.
///
/// Discovery returns the whole list; lookup by name returns the first tool in the list
/// whose name matches.
pub struct StaticToolProvider {
    /// Tools in the order they were supplied.
    tools: Vec<Arc<dyn Tool>>,
}
298
299impl std::fmt::Debug for StaticToolProvider {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 f.debug_struct("StaticToolProvider")
302 .field("tools_count", &self.tools.len())
303 .finish()
304 }
305}
306
307impl StaticToolProvider {
308 #[must_use]
310 pub fn new(tools: Vec<Box<dyn Tool>>) -> Self {
311 Self {
312 tools: tools.into_iter().map(Arc::from).collect(),
313 }
314 }
315
316 #[must_use]
318 pub fn from_arcs(tools: Vec<Arc<dyn Tool>>) -> Self {
319 Self { tools }
320 }
321}
322
323impl ToolProvider for StaticToolProvider {
324 fn discover_tools(&self) -> BoxFuture<'_, Result<Vec<Arc<dyn Tool>>, SynwireError>> {
325 let tools = self.tools.clone();
326 Box::pin(async move { Ok(tools) })
327 }
328
329 fn get_tool(&self, name: &str) -> BoxFuture<'_, Result<Option<Arc<dyn Tool>>, SynwireError>> {
330 let found = self.tools.iter().find(|t| t.name() == name).cloned();
331 Box::pin(async move { Ok(found) })
332 }
333}
334
/// How a [`CompositeToolProvider`] resolves two discovered tools that share a name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum NameCollisionPolicy {
    /// Keep the tool seen first during discovery and drop later duplicates (the default).
    #[default]
    KeepFirst,
    /// Let a later duplicate replace the earlier tool during discovery; the replacement
    /// keeps the list position where the name first appeared.
    KeepLast,
    /// Fail discovery with a validation error as soon as a duplicate name is seen.
    Error,
}
351
/// A [`ToolProvider`] that aggregates several other providers.
///
/// Providers are consulted in the order given. When tools are discovered, duplicate
/// names are handled according to the configured [`NameCollisionPolicy`].
pub struct CompositeToolProvider {
    /// Wrapped providers, in priority/registration order.
    providers: Vec<Box<dyn ToolProvider>>,
    /// How duplicate tool names are resolved.
    collision_policy: NameCollisionPolicy,
}
364
365impl std::fmt::Debug for CompositeToolProvider {
366 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367 f.debug_struct("CompositeToolProvider")
368 .field("providers_count", &self.providers.len())
369 .field("collision_policy", &self.collision_policy)
370 .finish()
371 }
372}
373
374impl CompositeToolProvider {
375 #[must_use]
378 pub fn new(
379 providers: Vec<Box<dyn ToolProvider>>,
380 collision_policy: NameCollisionPolicy,
381 ) -> Self {
382 Self {
383 providers,
384 collision_policy,
385 }
386 }
387
388 #[must_use]
390 pub fn with_keep_first(providers: Vec<Box<dyn ToolProvider>>) -> Self {
391 Self::new(providers, NameCollisionPolicy::KeepFirst)
392 }
393}
394
395impl ToolProvider for CompositeToolProvider {
396 fn discover_tools(&self) -> BoxFuture<'_, Result<Vec<Arc<dyn Tool>>, SynwireError>> {
397 Box::pin(async move {
398 let mut map: HashMap<String, Arc<dyn Tool>> = HashMap::new();
399 let mut ordered: Vec<Arc<dyn Tool>> = Vec::new();
400
401 for provider in &self.providers {
402 let tools = provider.discover_tools().await?;
403 for tool in tools {
404 let name = tool.name().to_owned();
405 match self.collision_policy {
406 NameCollisionPolicy::KeepFirst => {
407 if !map.contains_key(&name) {
408 let _ = map.insert(name.clone(), Arc::clone(&tool));
409 ordered.push(tool);
410 }
411 }
412 NameCollisionPolicy::KeepLast => {
413 if let Some(pos) = ordered.iter().position(|t| t.name() == name) {
414 ordered[pos] = Arc::clone(&tool);
415 } else {
416 ordered.push(Arc::clone(&tool));
417 }
418 let _ = map.insert(name, tool);
419 }
420 NameCollisionPolicy::Error => {
421 if map.contains_key(&name) {
422 return Err(SynwireError::Tool(
423 crate::error::ToolError::ValidationFailed {
424 message: format!(
425 "CompositeToolProvider: name collision for tool '{name}'"
426 ),
427 },
428 ));
429 }
430 let _ = map.insert(name, Arc::clone(&tool));
431 ordered.push(tool);
432 }
433 }
434 }
435 }
436
437 Ok(ordered)
438 })
439 }
440
441 fn get_tool(&self, name: &str) -> BoxFuture<'_, Result<Option<Arc<dyn Tool>>, SynwireError>> {
442 let name = name.to_owned();
443 Box::pin(async move {
444 for provider in &self.providers {
445 if let Some(tool) = provider.get_tool(&name).await? {
446 return Ok(Some(tool));
447 }
448 }
449 Ok(None)
450 })
451 }
452}
453
#[cfg(test)]
// Tests unwrap and panic freely: a panic is the intended failure signal here.
#[allow(clippy::unwrap_used, clippy::panic)]
mod provider_tests {
    use super::*;

    /// Builds a no-op tool whose description equals its name.
    fn make_tool(name: &str) -> Box<dyn Tool> {
        make_described_tool(name, name)
    }

    /// Builds a no-op tool with an explicit description, so tests can tell apart two
    /// tools that share a name.
    fn make_described_tool(name: &str, description: &str) -> Box<dyn Tool> {
        StructuredTool::builder()
            .name(name)
            .description(description)
            .schema(ToolSchema {
                name: name.into(),
                description: description.into(),
                parameters: serde_json::json!({"type": "object"}),
            })
            .func(|_| Box::pin(async { Ok(ToolOutput::default()) }))
            .build()
            .map(|t| Box::new(t) as Box<dyn Tool>)
            .unwrap()
    }

    #[tokio::test]
    async fn static_provider_discovers_all_tools() {
        let provider = StaticToolProvider::new(vec![make_tool("a"), make_tool("b")]);
        let tools = provider.discover_tools().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    #[tokio::test]
    async fn static_provider_get_by_name() {
        let provider = StaticToolProvider::new(vec![make_tool("search")]);
        let tool = provider.get_tool("search").await.unwrap();
        assert!(tool.is_some());
        let missing = provider.get_tool("missing").await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn composite_keep_first_deduplicates() {
        let p1 = Box::new(StaticToolProvider::new(vec![make_tool("x")]));
        let p2 = Box::new(StaticToolProvider::new(vec![
            make_tool("x"),
            make_tool("y"),
        ]));
        let composite = CompositeToolProvider::with_keep_first(vec![p1, p2]);
        let tools = composite.discover_tools().await.unwrap();
        assert_eq!(tools.len(), 2);
        let names: Vec<_> = tools.iter().map(|t| t.name()).collect();
        assert!(names.contains(&"x"));
        assert!(names.contains(&"y"));
    }

    #[tokio::test]
    async fn composite_keep_first_keeps_earliest_definition() {
        let p1 = Box::new(StaticToolProvider::new(vec![make_described_tool(
            "x", "first",
        )]));
        let p2 = Box::new(StaticToolProvider::new(vec![
            make_described_tool("y", "y"),
            make_described_tool("x", "second"),
        ]));
        let composite = CompositeToolProvider::with_keep_first(vec![p1, p2]);
        let tools = composite.discover_tools().await.unwrap();
        let summary: Vec<_> = tools.iter().map(|t| (t.name(), t.description())).collect();
        assert_eq!(summary, vec![("x", "first"), ("y", "y")]);
    }

    #[tokio::test]
    async fn composite_keep_last_replaces_in_place() {
        let p1 = Box::new(StaticToolProvider::new(vec![make_described_tool(
            "x", "first",
        )]));
        let p2 = Box::new(StaticToolProvider::new(vec![
            make_described_tool("y", "y"),
            make_described_tool("x", "second"),
        ]));
        let composite = CompositeToolProvider::new(vec![p1, p2], NameCollisionPolicy::KeepLast);
        let tools = composite.discover_tools().await.unwrap();
        let summary: Vec<_> = tools.iter().map(|t| (t.name(), t.description())).collect();
        // The later "x" wins but stays where "x" first appeared.
        assert_eq!(summary, vec![("x", "second"), ("y", "y")]);
    }

    #[tokio::test]
    async fn composite_error_policy_on_collision() {
        let p1 = Box::new(StaticToolProvider::new(vec![make_tool("dup")]));
        let p2 = Box::new(StaticToolProvider::new(vec![make_tool("dup")]));
        let composite = CompositeToolProvider::new(vec![p1, p2], NameCollisionPolicy::Error);
        let result = composite.discover_tools().await;
        match result {
            Err(e) => assert!(e.to_string().contains("collision")),
            Ok(_) => panic!("expected a collision error"),
        }
    }

    #[tokio::test]
    async fn composite_error_policy_without_collision_succeeds() {
        let p1 = Box::new(StaticToolProvider::new(vec![make_tool("a")]));
        let p2 = Box::new(StaticToolProvider::new(vec![make_tool("b")]));
        let composite = CompositeToolProvider::new(vec![p1, p2], NameCollisionPolicy::Error);
        let tools = composite.discover_tools().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    #[tokio::test]
    async fn composite_get_tool_searches_all_providers() {
        let p1 = Box::new(StaticToolProvider::new(vec![make_tool("a")]));
        let p2 = Box::new(StaticToolProvider::new(vec![make_tool("b")]));
        let composite = CompositeToolProvider::with_keep_first(vec![p1, p2]);
        assert!(composite.get_tool("b").await.unwrap().is_some());
        assert!(composite.get_tool("missing").await.unwrap().is_none());
    }
}