Skip to main content

serdes_ai_toolsets/
external.rs

//! External toolset implementation.
//!
//! This module provides `ExternalToolset`, a toolset for tools that are
//! executed externally (not by the agent itself).
5
6use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolDefinition, ToolError, ToolReturn};
9use std::collections::HashMap;
10use std::marker::PhantomData;
11
12use crate::{AbstractToolset, ToolsetTool};
13
14/// Toolset for externally-executed tools.
15///
16/// This is used when tools need to be exposed to the model but will be
17/// executed by an external system. When any tool is called, it returns
18/// `ToolError::CallDeferred` so the agent knows to defer execution.
19///
20/// # Example
21///
22/// ```ignore
23/// use serdes_ai_toolsets::ExternalToolset;
24/// use serdes_ai_tools::ToolDefinition;
25///
26/// let external = ExternalToolset::new()
27///     .with_id("external_api")
28///     .definition(ToolDefinition::new("api_call", "Call external API"));
29/// ```
30pub struct ExternalToolset<Deps = ()> {
31    id: Option<String>,
32    definitions: Vec<ToolDefinition>,
33    max_retries: u32,
34    _phantom: PhantomData<fn() -> Deps>,
35}
36
37impl<Deps> ExternalToolset<Deps> {
38    /// Create a new empty external toolset.
39    #[must_use]
40    pub fn new() -> Self {
41        Self {
42            id: None,
43            definitions: Vec::new(),
44            max_retries: 3,
45            _phantom: PhantomData,
46        }
47    }
48
49    /// Set the toolset ID.
50    #[must_use]
51    pub fn with_id(mut self, id: impl Into<String>) -> Self {
52        self.id = Some(id.into());
53        self
54    }
55
56    /// Set max retries.
57    #[must_use]
58    pub fn with_max_retries(mut self, retries: u32) -> Self {
59        self.max_retries = retries;
60        self
61    }
62
63    /// Add a tool definition.
64    #[must_use]
65    pub fn definition(mut self, def: ToolDefinition) -> Self {
66        self.definitions.push(def);
67        self
68    }
69
70    /// Add multiple tool definitions.
71    #[must_use]
72    pub fn definitions(mut self, defs: impl IntoIterator<Item = ToolDefinition>) -> Self {
73        self.definitions.extend(defs);
74        self
75    }
76
77    /// Get the number of definitions.
78    #[must_use]
79    pub fn len(&self) -> usize {
80        self.definitions.len()
81    }
82
83    /// Check if empty.
84    #[must_use]
85    pub fn is_empty(&self) -> bool {
86        self.definitions.is_empty()
87    }
88}
89
90impl<Deps> Default for ExternalToolset<Deps> {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
96#[async_trait]
97impl<Deps: Send + Sync> AbstractToolset<Deps> for ExternalToolset<Deps> {
98    fn id(&self) -> Option<&str> {
99        self.id.as_deref()
100    }
101
102    fn type_name(&self) -> &'static str {
103        "ExternalToolset"
104    }
105
106    async fn get_tools(
107        &self,
108        _ctx: &RunContext<Deps>,
109    ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
110        Ok(self
111            .definitions
112            .iter()
113            .map(|def| {
114                (
115                    def.name.clone(),
116                    ToolsetTool {
117                        toolset_id: self.id.clone(),
118                        tool_def: def.clone(),
119                        max_retries: self.max_retries,
120                    },
121                )
122            })
123            .collect())
124    }
125
126    async fn call_tool(
127        &self,
128        name: &str,
129        args: JsonValue,
130        _ctx: &RunContext<Deps>,
131        _tool: &ToolsetTool,
132    ) -> Result<ToolReturn, ToolError> {
133        // Always defer external tool calls
134        Err(ToolError::CallDeferred {
135            tool_name: name.to_string(),
136            args,
137        })
138    }
139}
140
141impl<Deps> std::fmt::Debug for ExternalToolset<Deps> {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        f.debug_struct("ExternalToolset")
144            .field("id", &self.id)
145            .field("definitions", &self.definitions.len())
146            .field("max_retries", &self.max_retries)
147            .finish()
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built toolset has no definitions and no ID.
    #[test]
    fn test_external_toolset_new() {
        let toolset: ExternalToolset<()> = ExternalToolset::new();

        assert!(toolset.id().is_none());
        assert!(toolset.is_empty());
    }

    /// The builder records the ID and every definition added one by one.
    #[test]
    fn test_external_toolset_with_definitions() {
        let api_call = ToolDefinition::new("api_call", "Call API");
        let webhook = ToolDefinition::new("webhook", "Send webhook");

        let toolset = ExternalToolset::<()>::new()
            .with_id("external")
            .definition(api_call)
            .definition(webhook);

        assert_eq!(toolset.id(), Some("external"));
        assert_eq!(toolset.len(), 2);
    }

    /// `get_tools` returns one entry per definition, keyed by the tool name.
    #[tokio::test]
    async fn test_external_toolset_get_tools() {
        let definition = ToolDefinition::new("test", "Test tool");
        let toolset = ExternalToolset::<()>::new().definition(definition);
        let ctx = RunContext::minimal("test");

        let tools = toolset.get_tools(&ctx).await.unwrap();

        assert!(tools.contains_key("test"));
        assert_eq!(tools.len(), 1);
    }

    /// Calling a tool never runs it: the name and arguments come back inside
    /// `ToolError::CallDeferred`.
    #[tokio::test]
    async fn test_external_toolset_call_deferred() {
        let definition = ToolDefinition::new("api_call", "Call API");
        let toolset = ExternalToolset::<()>::new().definition(definition);
        let ctx = RunContext::minimal("test");

        let tools = toolset.get_tools(&ctx).await.unwrap();
        let tool = tools.get("api_call").unwrap();
        let payload = serde_json::json!({"endpoint": "/test"});

        let result = toolset.call_tool("api_call", payload, &ctx, tool).await;

        match result {
            Err(ToolError::CallDeferred { tool_name, args }) => {
                assert_eq!(tool_name, "api_call");
                assert_eq!(args["endpoint"], "/test");
            }
            _ => panic!("expected ToolError::CallDeferred"),
        }
    }
}