Skip to main content

serdes_ai_toolsets/
prepared.rs

1//! Prepared toolset implementation.
2//!
3//! This module provides `PreparedToolset`, which modifies tool definitions
4//! at runtime using a prepare function.
5
6use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolDefinition, ToolError, ToolReturn};
9use std::collections::HashMap;
10use std::marker::PhantomData;
11
12use crate::{AbstractToolset, ToolsetTool};
13
/// A toolset wrapper that rewrites or filters the inner toolset's tool
/// definitions at runtime using a prepare function.
///
/// The prepare function is given the current [`RunContext`] together with the
/// inner toolset's definitions. It can therefore inject context-dependent
/// descriptions or hide tools the current user must not see. Returning `None`
/// hides every tool.
///
/// # Example
///
/// ```ignore
/// use serdes_ai_toolsets::{PreparedToolset, FunctionToolset};
///
/// let toolset = FunctionToolset::new().tool(admin_tool);
///
/// // Only admins get to see `admin_*` tools.
/// let prepared = PreparedToolset::new(toolset, |ctx, defs| {
///     if ctx.deps.is_admin {
///         Some(defs)
///     } else {
///         Some(defs.into_iter().filter(|d| !d.name.starts_with("admin_")).collect())
///     }
/// });
/// ```
pub struct PreparedToolset<T, F, Deps = ()> {
    /// The wrapped toolset that actually implements the tools.
    inner: T,
    /// Receives the run context and the inner definitions; its output decides
    /// which tools are exposed and with which definitions.
    prepare_fn: F,
    /// Ties `Deps` to the type without storing a value of it. Using
    /// `fn() -> Deps` means `Deps` itself adds no `Send`/`Sync` requirement.
    _phantom: PhantomData<fn() -> Deps>,
}
41
42impl<T, F, Deps> PreparedToolset<T, F, Deps>
43where
44    T: AbstractToolset<Deps>,
45    F: Fn(&RunContext<Deps>, Vec<ToolDefinition>) -> Option<Vec<ToolDefinition>> + Send + Sync,
46{
47    /// Create a new prepared toolset.
48    pub fn new(inner: T, prepare_fn: F) -> Self {
49        Self {
50            inner,
51            prepare_fn,
52            _phantom: PhantomData,
53        }
54    }
55
56    /// Get the inner toolset.
57    #[must_use]
58    pub fn inner(&self) -> &T {
59        &self.inner
60    }
61}
62
63#[async_trait]
64impl<T, F, Deps> AbstractToolset<Deps> for PreparedToolset<T, F, Deps>
65where
66    T: AbstractToolset<Deps>,
67    F: Fn(&RunContext<Deps>, Vec<ToolDefinition>) -> Option<Vec<ToolDefinition>> + Send + Sync,
68    Deps: Send + Sync,
69{
70    fn id(&self) -> Option<&str> {
71        self.inner.id()
72    }
73
74    fn type_name(&self) -> &'static str {
75        "PreparedToolset"
76    }
77
78    fn label(&self) -> String {
79        format!("PreparedToolset({})", self.inner.label())
80    }
81
82    async fn get_tools(
83        &self,
84        ctx: &RunContext<Deps>,
85    ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
86        let inner_tools = self.inner.get_tools(ctx).await?;
87
88        // Extract definitions for the prepare function
89        let defs: Vec<ToolDefinition> = inner_tools.values().map(|t| t.tool_def.clone()).collect();
90
91        // Apply the prepare function
92        let prepared_defs = match (self.prepare_fn)(ctx, defs) {
93            Some(defs) => defs,
94            None => return Ok(HashMap::new()), // Return empty if prepare returns None
95        };
96
97        // Build result, keeping only tools that are in the prepared definitions
98        let prepared_names: std::collections::HashSet<_> =
99            prepared_defs.iter().map(|d| d.name.clone()).collect();
100
101        // Create a map of prepared definitions
102        let def_map: HashMap<String, ToolDefinition> = prepared_defs
103            .into_iter()
104            .map(|d| (d.name.clone(), d))
105            .collect();
106
107        Ok(inner_tools
108            .into_iter()
109            .filter(|(name, _)| prepared_names.contains(name))
110            .map(|(name, mut tool)| {
111                // Update with the potentially modified definition
112                if let Some(prepared_def) = def_map.get(&name) {
113                    tool.tool_def = prepared_def.clone();
114                }
115                (name, tool)
116            })
117            .collect())
118    }
119
120    async fn call_tool(
121        &self,
122        name: &str,
123        args: JsonValue,
124        ctx: &RunContext<Deps>,
125        tool: &ToolsetTool,
126    ) -> Result<ToolReturn, ToolError> {
127        self.inner.call_tool(name, args, ctx, tool).await
128    }
129
130    async fn enter(&self) -> Result<(), ToolError> {
131        self.inner.enter().await
132    }
133
134    async fn exit(&self) -> Result<(), ToolError> {
135        self.inner.exit().await
136    }
137}
138
139impl<T: std::fmt::Debug, F, Deps> std::fmt::Debug for PreparedToolset<T, F, Deps> {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        f.debug_struct("PreparedToolset")
142            .field("inner", &self.inner)
143            .finish()
144    }
145}
146
147/// Common prepare functions.
148pub mod preparers {
149    use serdes_ai_tools::{RunContext, ToolDefinition};
150
151    /// Add a suffix to all tool descriptions.
152    pub fn add_description_suffix<Deps>(
153        suffix: &str,
154    ) -> impl Fn(&RunContext<Deps>, Vec<ToolDefinition>) -> Option<Vec<ToolDefinition>> + Send + Sync + '_
155    {
156        move |_, defs| {
157            Some(
158                defs.into_iter()
159                    .map(|mut d| {
160                        d.description = format!("{} {}", d.description, suffix);
161                        d
162                    })
163                    .collect(),
164            )
165        }
166    }
167
168    /// Filter tools based on a predicate.
169    pub fn filter<Deps, F>(
170        pred: F,
171    ) -> impl Fn(&RunContext<Deps>, Vec<ToolDefinition>) -> Option<Vec<ToolDefinition>> + Send + Sync
172    where
173        F: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
174    {
175        move |ctx, defs| Some(defs.into_iter().filter(|d| pred(ctx, d)).collect())
176    }
177
178    /// Sort tools by name.
179    pub fn sort_by_name<Deps>(
180    ) -> impl Fn(&RunContext<Deps>, Vec<ToolDefinition>) -> Option<Vec<ToolDefinition>> + Send + Sync
181    {
182        |_, mut defs| {
183            defs.sort_by(|a, b| a.name.cmp(&b.name));
184            Some(defs)
185        }
186    }
187
188    /// Limit the number of tools.
189    pub fn limit<Deps>(
190        max: usize,
191    ) -> impl Fn(&RunContext<Deps>, Vec<ToolDefinition>) -> Option<Vec<ToolDefinition>> + Send + Sync
192    {
193        move |_, defs| Some(defs.into_iter().take(max).collect())
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::Tool;

    struct ToolA;

    #[async_trait]
    impl Tool<()> for ToolA {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_a", "Tool A")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("A"))
        }
    }

    struct AdminTool;

    #[async_trait]
    impl Tool<()> for AdminTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("admin_delete", "Admin delete")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("Deleted"))
        }
    }

    #[tokio::test]
    async fn test_prepared_toolset_filter() {
        let toolset = FunctionToolset::new().tool(ToolA).tool(AdminTool);

        // Hide admin tools
        let prepared = PreparedToolset::new(toolset, |_, defs| {
            Some(
                defs.into_iter()
                    .filter(|d| !d.name.starts_with("admin_"))
                    .collect(),
            )
        });

        let ctx = RunContext::minimal("test");
        let tools = prepared.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("tool_a"));
        assert!(!tools.contains_key("admin_delete"));
    }

    #[tokio::test]
    async fn test_prepared_toolset_modify_description() {
        let toolset = FunctionToolset::new().tool(ToolA);

        let prepared = PreparedToolset::new(toolset, |_, defs| {
            Some(
                defs.into_iter()
                    .map(|mut d| {
                        d.description = format!("[MODIFIED] {}", d.description);
                        d
                    })
                    .collect(),
            )
        });

        let ctx = RunContext::minimal("test");
        let tools = prepared.get_tools(&ctx).await.unwrap();

        let tool = tools.get("tool_a").unwrap();
        assert!(tool.tool_def.description.starts_with("[MODIFIED]"));
    }

    #[tokio::test]
    async fn test_prepared_toolset_returns_none() {
        let toolset = FunctionToolset::new().tool(ToolA);

        // Return None to hide all tools
        let prepared = PreparedToolset::new(toolset, |_, _| None);

        let ctx = RunContext::minimal("test");
        let tools = prepared.get_tools(&ctx).await.unwrap();

        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn test_prepared_toolset_call_still_works() {
        let toolset = FunctionToolset::new().tool(ToolA);

        let prepared = PreparedToolset::new(toolset, |_, defs| Some(defs));

        let ctx = RunContext::minimal("test");
        let tools = prepared.get_tools(&ctx).await.unwrap();
        let tool = tools.get("tool_a").unwrap();

        let result = prepared
            .call_tool("tool_a", serde_json::json!({}), &ctx, tool)
            .await
            .unwrap();

        assert_eq!(result.as_text(), Some("A"));
    }

    /// Definitions the prepare function invents have no implementation and
    /// must not appear in the result.
    #[tokio::test]
    async fn test_prepared_toolset_drops_unknown_definitions() {
        let toolset = FunctionToolset::new().tool(ToolA);

        let prepared = PreparedToolset::new(toolset, |_, mut defs| {
            defs.push(ToolDefinition::new("ghost", "Has no implementation"));
            Some(defs)
        });

        let ctx = RunContext::minimal("test");
        let tools = prepared.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("tool_a"));
        assert!(!tools.contains_key("ghost"));
    }

    #[tokio::test]
    async fn test_preparers_limit() {
        let toolset = FunctionToolset::new().tool(ToolA).tool(AdminTool);
        let prepared = PreparedToolset::new(toolset, preparers::limit::<()>(1));

        let ctx = RunContext::minimal("test");
        let tools = prepared.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
    }

    #[tokio::test]
    async fn test_preparers_add_description_suffix() {
        let toolset = FunctionToolset::new().tool(ToolA);
        let prepared =
            PreparedToolset::new(toolset, preparers::add_description_suffix::<()>("(beta)"));

        let ctx = RunContext::minimal("test");
        let tools = prepared.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.get("tool_a").unwrap().tool_def.description, "Tool A (beta)");
    }

    #[tokio::test]
    async fn test_preparers_filter() {
        let toolset = FunctionToolset::new().tool(ToolA).tool(AdminTool);
        let prepared = PreparedToolset::new(
            toolset,
            preparers::filter::<(), _>(|_: &RunContext<()>, d: &ToolDefinition| {
                !d.name.starts_with("admin_")
            }),
        );

        let ctx = RunContext::minimal("test");
        let tools = prepared.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("tool_a"));
    }
}