Skip to main content

synwire_core/tools/
structured.rs

//! Structured tool with a builder API, plus static and composite tool provider implementations.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::BoxFuture;
7use crate::error::SynwireError;
8use crate::tools::traits::{Tool, ToolProvider, validate_tool_name};
9use crate::tools::types::{ToolOutput, ToolSchema};
10
11/// Type alias for the tool function closure.
12type ToolFn = Arc<
13    dyn Fn(serde_json::Value) -> BoxFuture<'static, Result<ToolOutput, SynwireError>> + Send + Sync,
14>;
15
16/// A structured tool with a typed schema and a closure for execution.
17///
18/// Use [`StructuredToolBuilder`] to construct instances.
19///
20/// # Example
21///
22/// ```
23/// use synwire_core::tools::{StructuredTool, Tool, ToolOutput, ToolSchema};
24/// use synwire_core::error::SynwireError;
25///
26/// let tool = StructuredTool::builder()
27///     .name("echo")
28///     .description("Echoes input")
29///     .schema(ToolSchema {
30///         name: "echo".into(),
31///         description: "Echoes input".into(),
32///         parameters: serde_json::json!({"type": "object"}),
33///     })
34///     .func(|input| Box::pin(async move {
35///         Ok(ToolOutput {
36///             content: input.to_string(),
37///             ..Default::default()
38///         })
39///     }))
40///     .build()
41///     .expect("valid tool");
42///
43/// assert_eq!(tool.name(), "echo");
44/// ```
45pub struct StructuredTool {
46    name: String,
47    description: String,
48    schema: ToolSchema,
49    func: ToolFn,
50}
51
52impl std::fmt::Debug for StructuredTool {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("StructuredTool")
55            .field("name", &self.name)
56            .field("description", &self.description)
57            .field("schema", &self.schema)
58            .field("func", &"<closure>")
59            .finish()
60    }
61}
62
63impl StructuredTool {
64    /// Returns a new builder for constructing a `StructuredTool`.
65    pub fn builder() -> StructuredToolBuilder {
66        StructuredToolBuilder {
67            name: None,
68            description: None,
69            schema: None,
70            func: None,
71        }
72    }
73}
74
75impl Tool for StructuredTool {
76    fn name(&self) -> &str {
77        &self.name
78    }
79
80    fn description(&self) -> &str {
81        &self.description
82    }
83
84    fn schema(&self) -> &ToolSchema {
85        &self.schema
86    }
87
88    fn invoke(&self, input: serde_json::Value) -> BoxFuture<'_, Result<ToolOutput, SynwireError>> {
89        (self.func)(input)
90    }
91}
92
93/// Builder for [`StructuredTool`].
94///
95/// All fields are required. The builder validates the tool name at
96/// [`build()`](Self::build) time.
97#[derive(Default)]
98pub struct StructuredToolBuilder {
99    name: Option<String>,
100    description: Option<String>,
101    schema: Option<ToolSchema>,
102    func: Option<ToolFn>,
103}
104
105impl StructuredToolBuilder {
106    /// Set the tool name.
107    #[must_use]
108    pub fn name(mut self, name: impl Into<String>) -> Self {
109        self.name = Some(name.into());
110        self
111    }
112
113    /// Set the tool description.
114    #[must_use]
115    pub fn description(mut self, description: impl Into<String>) -> Self {
116        self.description = Some(description.into());
117        self
118    }
119
120    /// Set the tool schema.
121    #[must_use]
122    pub fn schema(mut self, schema: ToolSchema) -> Self {
123        self.schema = Some(schema);
124        self
125    }
126
127    /// Set the tool function.
128    #[must_use]
129    pub fn func<F>(mut self, f: F) -> Self
130    where
131        F: Fn(serde_json::Value) -> BoxFuture<'static, Result<ToolOutput, SynwireError>>
132            + Send
133            + Sync
134            + 'static,
135    {
136        self.func = Some(Arc::new(f));
137        self
138    }
139
140    /// Build the structured tool.
141    ///
142    /// # Errors
143    ///
144    /// Returns [`SynwireError::Tool`] if:
145    /// - Any required field is missing (reported as a validation failure)
146    /// - The tool name fails validation
147    pub fn build(self) -> Result<StructuredTool, SynwireError> {
148        let name = self.name.ok_or_else(|| {
149            SynwireError::Tool(crate::error::ToolError::ValidationFailed {
150                message: "tool name is required".into(),
151            })
152        })?;
153        let description = self.description.ok_or_else(|| {
154            SynwireError::Tool(crate::error::ToolError::ValidationFailed {
155                message: "tool description is required".into(),
156            })
157        })?;
158        let schema = self.schema.ok_or_else(|| {
159            SynwireError::Tool(crate::error::ToolError::ValidationFailed {
160                message: "tool schema is required".into(),
161            })
162        })?;
163        let func = self.func.ok_or_else(|| {
164            SynwireError::Tool(crate::error::ToolError::ValidationFailed {
165                message: "tool function is required".into(),
166            })
167        })?;
168
169        validate_tool_name(&name)?;
170
171        Ok(StructuredTool {
172            name,
173            description,
174            schema,
175            func,
176        })
177    }
178}
179
#[cfg(test)]
// Unit tests unwrap freely: a panic is the desired failure signal here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Minimal object schema whose name mirrors the tool under test.
    fn make_schema(name: &str) -> ToolSchema {
        ToolSchema {
            name: name.into(),
            description: "test".into(),
            parameters: serde_json::json!({"type": "object"}),
        }
    }

    /// Closure that succeeds and echoes its JSON input back as the content.
    fn make_echo_func()
    -> impl Fn(serde_json::Value) -> BoxFuture<'static, Result<ToolOutput, SynwireError>> + Send + Sync
    {
        |input| {
            Box::pin(async move {
                Ok(ToolOutput {
                    content: input.to_string(),
                    artifact: None,
                    binary_results: Vec::new(),
                    status: crate::tools::ToolResultStatus::Success,
                    telemetry: None,
                    content_type: None,
                })
            })
        }
    }

    #[tokio::test]
    async fn structured_tool_invoke_valid_input() {
        let tool = StructuredTool::builder()
            .name("echo")
            .description("echoes input")
            .schema(make_schema("echo"))
            .func(make_echo_func())
            .build()
            .unwrap();

        let result = tool
            .invoke(serde_json::json!({"msg": "hello"}))
            .await
            .unwrap();
        assert!(result.content.contains("hello"));
    }

    #[test]
    fn schema_is_serialisable() {
        let tool = StructuredTool::builder()
            .name("my-tool")
            .description("a tool")
            .schema(make_schema("my-tool"))
            .func(make_echo_func())
            .build()
            .unwrap();

        let json = serde_json::to_value(tool.schema()).unwrap();
        assert_eq!(json["name"], "my-tool");
    }

    #[tokio::test]
    async fn invoke_with_error_func() {
        let tool = StructuredTool::builder()
            .name("fail-tool")
            .description("always fails")
            .schema(make_schema("fail-tool"))
            .func(|_input| {
                Box::pin(async {
                    Err(SynwireError::Tool(
                        crate::error::ToolError::InvocationFailed {
                            message: "boom".into(),
                        },
                    ))
                })
            })
            .build()
            .unwrap();

        let result = tool.invoke(serde_json::json!({})).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("boom"));
    }

    #[test]
    fn builder_rejects_invalid_name() {
        let result = StructuredTool::builder()
            .name("bad name!")
            .description("d")
            .schema(make_schema("bad name!"))
            .func(make_echo_func())
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn builder_requires_all_fields() {
        // Missing func
        let result = StructuredTool::builder()
            .name("ok")
            .description("d")
            .schema(make_schema("ok"))
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("function"));
    }

    #[test]
    fn builder_requires_name() {
        let err = StructuredTool::builder()
            .description("d")
            .schema(make_schema("ok"))
            .func(make_echo_func())
            .build()
            .unwrap_err();
        assert!(err.to_string().contains("name is required"));
    }

    #[test]
    fn builder_requires_description() {
        let err = StructuredTool::builder()
            .name("ok")
            .schema(make_schema("ok"))
            .func(make_echo_func())
            .build()
            .unwrap_err();
        assert!(err.to_string().contains("description is required"));
    }

    #[test]
    fn builder_requires_schema() {
        let err = StructuredTool::builder()
            .name("ok")
            .description("d")
            .func(make_echo_func())
            .build()
            .unwrap_err();
        assert!(err.to_string().contains("schema is required"));
    }
}
287
288// ---------------------------------------------------------------------------
289// StaticToolProvider
290// ---------------------------------------------------------------------------
291
292/// A [`ToolProvider`] backed by a fixed, pre-built list of tools.
293///
294/// Useful for registering tools known at construction time.
295pub struct StaticToolProvider {
296    tools: Vec<Arc<dyn Tool>>,
297}
298
299impl std::fmt::Debug for StaticToolProvider {
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        f.debug_struct("StaticToolProvider")
302            .field("tools_count", &self.tools.len())
303            .finish()
304    }
305}
306
307impl StaticToolProvider {
308    /// Creates a new `StaticToolProvider` from a list of boxed tools.
309    #[must_use]
310    pub fn new(tools: Vec<Box<dyn Tool>>) -> Self {
311        Self {
312            tools: tools.into_iter().map(Arc::from).collect(),
313        }
314    }
315
316    /// Creates a new `StaticToolProvider` from a list of Arc'd tools.
317    #[must_use]
318    pub fn from_arcs(tools: Vec<Arc<dyn Tool>>) -> Self {
319        Self { tools }
320    }
321}
322
323impl ToolProvider for StaticToolProvider {
324    fn discover_tools(&self) -> BoxFuture<'_, Result<Vec<Arc<dyn Tool>>, SynwireError>> {
325        let tools = self.tools.clone();
326        Box::pin(async move { Ok(tools) })
327    }
328
329    fn get_tool(&self, name: &str) -> BoxFuture<'_, Result<Option<Arc<dyn Tool>>, SynwireError>> {
330        let found = self.tools.iter().find(|t| t.name() == name).cloned();
331        Box::pin(async move { Ok(found) })
332    }
333}
334
335// ---------------------------------------------------------------------------
336// NameCollisionPolicy
337// ---------------------------------------------------------------------------
338
339/// Policy for handling tool name collisions in a [`CompositeToolProvider`].
340#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
341#[non_exhaustive]
342pub enum NameCollisionPolicy {
343    /// Keep the first tool registered with a given name (ignore later duplicates).
344    #[default]
345    KeepFirst,
346    /// Keep the last tool registered with a given name.
347    KeepLast,
348    /// Return an error if a name collision occurs during discovery.
349    Error,
350}
351
352// ---------------------------------------------------------------------------
353// CompositeToolProvider
354// ---------------------------------------------------------------------------
355
356/// A [`ToolProvider`] that aggregates tools from multiple child providers.
357///
358/// Tools from all providers are merged according to the configured
359/// [`NameCollisionPolicy`].
360pub struct CompositeToolProvider {
361    providers: Vec<Box<dyn ToolProvider>>,
362    collision_policy: NameCollisionPolicy,
363}
364
365impl std::fmt::Debug for CompositeToolProvider {
366    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367        f.debug_struct("CompositeToolProvider")
368            .field("providers_count", &self.providers.len())
369            .field("collision_policy", &self.collision_policy)
370            .finish()
371    }
372}
373
374impl CompositeToolProvider {
375    /// Creates a new `CompositeToolProvider` with the given providers and
376    /// collision policy.
377    #[must_use]
378    pub fn new(
379        providers: Vec<Box<dyn ToolProvider>>,
380        collision_policy: NameCollisionPolicy,
381    ) -> Self {
382        Self {
383            providers,
384            collision_policy,
385        }
386    }
387
388    /// Creates a `CompositeToolProvider` with [`NameCollisionPolicy::KeepFirst`].
389    #[must_use]
390    pub fn with_keep_first(providers: Vec<Box<dyn ToolProvider>>) -> Self {
391        Self::new(providers, NameCollisionPolicy::KeepFirst)
392    }
393}
394
395impl ToolProvider for CompositeToolProvider {
396    fn discover_tools(&self) -> BoxFuture<'_, Result<Vec<Arc<dyn Tool>>, SynwireError>> {
397        Box::pin(async move {
398            let mut map: HashMap<String, Arc<dyn Tool>> = HashMap::new();
399            let mut ordered: Vec<Arc<dyn Tool>> = Vec::new();
400
401            for provider in &self.providers {
402                let tools = provider.discover_tools().await?;
403                for tool in tools {
404                    let name = tool.name().to_owned();
405                    match self.collision_policy {
406                        NameCollisionPolicy::KeepFirst => {
407                            if !map.contains_key(&name) {
408                                let _ = map.insert(name.clone(), Arc::clone(&tool));
409                                ordered.push(tool);
410                            }
411                        }
412                        NameCollisionPolicy::KeepLast => {
413                            if let Some(pos) = ordered.iter().position(|t| t.name() == name) {
414                                ordered[pos] = Arc::clone(&tool);
415                            } else {
416                                ordered.push(Arc::clone(&tool));
417                            }
418                            let _ = map.insert(name, tool);
419                        }
420                        NameCollisionPolicy::Error => {
421                            if map.contains_key(&name) {
422                                return Err(SynwireError::Tool(
423                                    crate::error::ToolError::ValidationFailed {
424                                        message: format!(
425                                            "CompositeToolProvider: name collision for tool '{name}'"
426                                        ),
427                                    },
428                                ));
429                            }
430                            let _ = map.insert(name, Arc::clone(&tool));
431                            ordered.push(tool);
432                        }
433                    }
434                }
435            }
436
437            Ok(ordered)
438        })
439    }
440
441    fn get_tool(&self, name: &str) -> BoxFuture<'_, Result<Option<Arc<dyn Tool>>, SynwireError>> {
442        let name = name.to_owned();
443        Box::pin(async move {
444            for provider in &self.providers {
445                if let Some(tool) = provider.get_tool(&name).await? {
446                    return Ok(Some(tool));
447                }
448            }
449            Ok(None)
450        })
451    }
452}
453
#[cfg(test)]
// Unit tests unwrap/panic freely: a panic is the desired failure signal here.
#[allow(clippy::unwrap_used, clippy::panic)]
mod provider_tests {
    use super::*;

    /// Builds a no-op tool whose description and schema description equal `description`.
    /// Distinct descriptions let tests tell same-named tools apart.
    fn make_described_tool(name: &str, description: &str) -> Box<dyn Tool> {
        StructuredTool::builder()
            .name(name)
            .description(description)
            .schema(ToolSchema {
                name: name.into(),
                description: description.into(),
                parameters: serde_json::json!({"type": "object"}),
            })
            .func(|_| Box::pin(async { Ok(ToolOutput::default()) }))
            .build()
            .map(|t| Box::new(t) as Box<dyn Tool>)
            .unwrap()
    }

    /// Builds a no-op tool whose description is its own name.
    fn make_tool(name: &str) -> Box<dyn Tool> {
        make_described_tool(name, name)
    }

    #[tokio::test]
    async fn static_provider_discovers_all_tools() {
        let provider = StaticToolProvider::new(vec![make_tool("a"), make_tool("b")]);
        let tools = provider.discover_tools().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    #[tokio::test]
    async fn static_provider_get_by_name() {
        let provider = StaticToolProvider::new(vec![make_tool("search")]);
        let tool = provider.get_tool("search").await.unwrap();
        assert!(tool.is_some());
        let missing = provider.get_tool("missing").await.unwrap();
        assert!(missing.is_none());
    }

    #[tokio::test]
    async fn composite_keep_first_deduplicates() {
        let p1 = Box::new(StaticToolProvider::new(vec![make_tool("x")]));
        let p2 = Box::new(StaticToolProvider::new(vec![
            make_tool("x"),
            make_tool("y"),
        ]));
        let composite = CompositeToolProvider::with_keep_first(vec![p1, p2]);
        let tools = composite.discover_tools().await.unwrap();
        assert_eq!(tools.len(), 2);
        let names: Vec<_> = tools.iter().map(|t| t.name()).collect();
        assert!(names.contains(&"x"));
        assert!(names.contains(&"y"));
    }

    #[tokio::test]
    async fn composite_keep_first_keeps_earliest_tool_and_order() {
        let p1: Box<dyn ToolProvider> = Box::new(StaticToolProvider::new(vec![
            make_described_tool("x", "first"),
            make_tool("y"),
        ]));
        let p2: Box<dyn ToolProvider> = Box::new(StaticToolProvider::new(vec![
            make_described_tool("x", "second"),
            make_tool("z"),
        ]));
        let composite = CompositeToolProvider::with_keep_first(vec![p1, p2]);
        let tools = composite.discover_tools().await.unwrap();
        let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
        assert_eq!(names, ["x", "y", "z"]);
        assert_eq!(tools[0].description(), "first");
    }

    #[tokio::test]
    async fn composite_keep_last_replaces_in_original_slot() {
        let p1: Box<dyn ToolProvider> = Box::new(StaticToolProvider::new(vec![
            make_described_tool("x", "first"),
            make_tool("y"),
        ]));
        let p2: Box<dyn ToolProvider> = Box::new(StaticToolProvider::new(vec![
            make_described_tool("x", "second"),
            make_tool("z"),
        ]));
        let composite = CompositeToolProvider::new(vec![p1, p2], NameCollisionPolicy::KeepLast);
        let tools = composite.discover_tools().await.unwrap();
        let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
        assert_eq!(names, ["x", "y", "z"]);
        assert_eq!(tools[0].description(), "second");
    }

    #[tokio::test]
    async fn composite_get_tool_searches_all_providers() {
        let p1: Box<dyn ToolProvider> = Box::new(StaticToolProvider::new(vec![make_tool("a")]));
        let p2: Box<dyn ToolProvider> = Box::new(StaticToolProvider::new(vec![make_tool("b")]));
        let composite = CompositeToolProvider::with_keep_first(vec![p1, p2]);
        assert!(composite.get_tool("b").await.unwrap().is_some());
        assert!(composite.get_tool("c").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn composite_error_policy_on_collision() {
        let p1 = Box::new(StaticToolProvider::new(vec![make_tool("dup")]));
        let p2 = Box::new(StaticToolProvider::new(vec![make_tool("dup")]));
        let composite = CompositeToolProvider::new(vec![p1, p2], NameCollisionPolicy::Error);
        let result = composite.discover_tools().await;
        // NOTE: unwrap_err() requires T: Debug; use match instead.
        match result {
            Err(e) => assert!(e.to_string().contains("collision")),
            Ok(_) => panic!("expected a collision error"),
        }
    }
}