Skip to main content

serdes_ai_toolsets/
filtered.rs

1//! Filtered toolset implementation.
2//!
3//! This module provides `FilteredToolset`, which wraps a toolset and
4//! filters tools based on a predicate function.
5
6use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolDefinition, ToolError, ToolReturn};
9use std::collections::HashMap;
10use std::marker::PhantomData;
11
12use crate::{AbstractToolset, ToolsetTool};
13
/// A toolset wrapper that only exposes the tools accepted by a predicate.
///
/// The predicate is called with the current [`RunContext`] and a tool's
/// [`ToolDefinition`]; a tool is visible only when the predicate returns
/// `true`. The predicate is also re-checked when a tool is called, so a
/// `call_tool` whose `tool` definition fails the predicate is rejected.
///
/// # Example
///
/// ```ignore
/// use serdes_ai_toolsets::{FilteredToolset, FunctionToolset};
///
/// let toolset = FunctionToolset::new()
///     .tool(tool_a)
///     .tool(dangerous_tool);
///
/// // Hide every tool whose name mentions "dangerous".
/// let filtered = FilteredToolset::new(toolset, |_ctx, def| {
///     !def.name.contains("dangerous")
/// });
/// ```
pub struct FilteredToolset<T, F, Deps = ()> {
    /// The wrapped toolset whose tools are being filtered.
    inner: T,
    /// Predicate deciding whether a tool is visible and callable.
    filter: F,
    /// Explicit ID for this wrapper; when `None`, the inner toolset's ID is used.
    id: Option<String>,
    /// Records `Deps` without storing a value of it. Using `fn() -> Deps`
    /// (rather than `Deps`) keeps the wrapper's `Send`/`Sync` status
    /// independent of `Deps`.
    _phantom: PhantomData<fn() -> Deps>,
}
38
39impl<T, F, Deps> FilteredToolset<T, F, Deps>
40where
41    T: AbstractToolset<Deps>,
42    F: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
43{
44    /// Create a new filtered toolset.
45    pub fn new(inner: T, filter: F) -> Self {
46        Self {
47            inner,
48            filter,
49            id: None,
50            _phantom: PhantomData,
51        }
52    }
53
54    /// Set an ID for this filtered toolset.
55    #[must_use]
56    pub fn with_id(mut self, id: impl Into<String>) -> Self {
57        self.id = Some(id.into());
58        self
59    }
60
61    /// Get the inner toolset.
62    #[must_use]
63    pub fn inner(&self) -> &T {
64        &self.inner
65    }
66}
67
68#[async_trait]
69impl<T, F, Deps> AbstractToolset<Deps> for FilteredToolset<T, F, Deps>
70where
71    T: AbstractToolset<Deps>,
72    F: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
73    Deps: Send + Sync,
74{
75    fn id(&self) -> Option<&str> {
76        self.id.as_deref().or_else(|| self.inner.id())
77    }
78
79    fn type_name(&self) -> &'static str {
80        "FilteredToolset"
81    }
82
83    fn label(&self) -> String {
84        format!("FilteredToolset({})", self.inner.label())
85    }
86
87    async fn get_tools(
88        &self,
89        ctx: &RunContext<Deps>,
90    ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
91        let all_tools = self.inner.get_tools(ctx).await?;
92
93        Ok(all_tools
94            .into_iter()
95            .filter(|(_, tool)| (self.filter)(ctx, &tool.tool_def))
96            .collect())
97    }
98
99    async fn call_tool(
100        &self,
101        name: &str,
102        args: JsonValue,
103        ctx: &RunContext<Deps>,
104        tool: &ToolsetTool,
105    ) -> Result<ToolReturn, ToolError> {
106        // Verify the tool passes the filter
107        if !(self.filter)(ctx, &tool.tool_def) {
108            return Err(ToolError::not_found(format!(
109                "Tool '{}' is not available (filtered out)",
110                name
111            )));
112        }
113
114        self.inner.call_tool(name, args, ctx, tool).await
115    }
116
117    async fn enter(&self) -> Result<(), ToolError> {
118        self.inner.enter().await
119    }
120
121    async fn exit(&self) -> Result<(), ToolError> {
122        self.inner.exit().await
123    }
124}
125
126impl<T: std::fmt::Debug, F, Deps> std::fmt::Debug for FilteredToolset<T, F, Deps> {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        f.debug_struct("FilteredToolset")
129            .field("inner", &self.inner)
130            .field("id", &self.id)
131            .finish()
132    }
133}
134
135/// Common filter predicates.
136pub mod filters {
137    use serdes_ai_tools::{RunContext, ToolDefinition};
138
139    /// Filter that allows only tools with names in the given list.
140    pub fn allow_names<Deps>(
141        names: Vec<String>,
142    ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync {
143        move |_, def| names.iter().any(|n| n == &def.name)
144    }
145
146    /// Filter that excludes tools with names in the given list.
147    pub fn deny_names<Deps>(
148        names: Vec<String>,
149    ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync {
150        move |_, def| !names.iter().any(|n| n == &def.name)
151    }
152
153    /// Filter that allows tools matching a prefix.
154    pub fn name_prefix<Deps>(
155        prefix: String,
156    ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync {
157        move |_, def| def.name.starts_with(&prefix)
158    }
159
160    /// Filter that allows tools matching a suffix.
161    pub fn name_suffix<Deps>(
162        suffix: String,
163    ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync {
164        move |_, def| def.name.ends_with(&suffix)
165    }
166
167    /// Filter that allows tools containing a substring.
168    pub fn name_contains<Deps>(
169        substring: String,
170    ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync {
171        move |_, def| def.name.contains(&substring)
172    }
173
174    /// Combine two filters with AND.
175    pub fn and<F1, F2, Deps>(
176        f1: F1,
177        f2: F2,
178    ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync
179    where
180        F1: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
181        F2: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
182    {
183        move |ctx, def| f1(ctx, def) && f2(ctx, def)
184    }
185
186    /// Combine two filters with OR.
187    pub fn or<F1, F2, Deps>(
188        f1: F1,
189        f2: F2,
190    ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync
191    where
192        F1: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
193        F2: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
194    {
195        move |ctx, def| f1(ctx, def) || f2(ctx, def)
196    }
197
198    /// Negate a filter.
199    pub fn not<F, Deps>(f: F) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync
200    where
201        F: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
202    {
203        move |ctx, def| !f(ctx, def)
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::Tool;

    struct ToolA;

    #[async_trait]
    impl Tool<()> for ToolA {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_a", "Tool A")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("A"))
        }
    }

    struct ToolB;

    #[async_trait]
    impl Tool<()> for ToolB {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_b", "Tool B")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("B"))
        }
    }

    struct DangerousTool;

    #[async_trait]
    impl Tool<()> for DangerousTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("dangerous_delete", "Dangerous delete operation")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("Deleted!"))
        }
    }

    #[tokio::test]
    async fn test_filtered_toolset() {
        let toolset = FunctionToolset::new()
            .tool(ToolA)
            .tool(ToolB)
            .tool(DangerousTool);

        // Filter out dangerous tools
        let filtered = FilteredToolset::new(toolset, |_, def| !def.name.contains("dangerous"));

        let ctx = RunContext::minimal("test");
        let tools = filtered.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 2);
        assert!(tools.contains_key("tool_a"));
        assert!(tools.contains_key("tool_b"));
        assert!(!tools.contains_key("dangerous_delete"));
    }

    #[tokio::test]
    async fn test_filtered_toolset_call_blocked() {
        let toolset = FunctionToolset::new().tool(ToolA).tool(DangerousTool);

        let filtered = FilteredToolset::new(toolset, |_, def| !def.name.contains("dangerous"));

        let ctx = RunContext::minimal("test");

        // Create a fake tool definition for the dangerous tool
        let fake_tool = ToolsetTool::new(ToolDefinition::new("dangerous_delete", "Dangerous"));

        let result = filtered
            .call_tool("dangerous_delete", serde_json::json!({}), &ctx, &fake_tool)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_filter_predicates_allow_names() {
        let toolset = FunctionToolset::new().tool(ToolA).tool(ToolB);

        let filtered =
            FilteredToolset::new(toolset, filters::allow_names(vec!["tool_a".to_string()]));

        let ctx = RunContext::minimal("test");
        let tools = filtered.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("tool_a"));
    }

    #[tokio::test]
    async fn test_filter_predicates_allow_names_duplicates_and_unknown() {
        let toolset = FunctionToolset::new().tool(ToolA).tool(ToolB);

        // Duplicates and names that match no tool must not change the result.
        let filtered = FilteredToolset::new(
            toolset,
            filters::allow_names(vec![
                "tool_a".to_string(),
                "tool_a".to_string(),
                "missing".to_string(),
            ]),
        );

        let ctx = RunContext::minimal("test");
        let tools = filtered.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("tool_a"));
    }

    #[tokio::test]
    async fn test_filter_predicates_deny_names() {
        let toolset = FunctionToolset::new().tool(ToolA).tool(ToolB);

        let filtered =
            FilteredToolset::new(toolset, filters::deny_names(vec!["tool_b".to_string()]));

        let ctx = RunContext::minimal("test");
        let tools = filtered.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("tool_a"));
    }

    #[tokio::test]
    async fn test_filter_predicates_combined() {
        let toolset = FunctionToolset::new()
            .tool(ToolA)
            .tool(ToolB)
            .tool(DangerousTool);

        // Allow tools starting with "tool" AND not containing "dangerous"
        let filtered = FilteredToolset::new(
            toolset,
            filters::and(
                filters::name_prefix("tool".to_string()),
                filters::not(filters::name_contains("dangerous".to_string())),
            ),
        );

        let ctx = RunContext::minimal("test");
        let tools = filtered.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 2);
        assert!(tools.contains_key("tool_a"));
        assert!(tools.contains_key("tool_b"));
        assert!(!tools.contains_key("dangerous_delete"));
    }

    #[tokio::test]
    async fn test_filter_predicates_and_requires_both() {
        let toolset = FunctionToolset::new()
            .tool(ToolA)
            .tool(ToolB)
            .tool(DangerousTool);

        // Each operand rejects a different tool: the prefix drops
        // "dangerous_delete", and the negated substring drops "tool_b".
        // An `and` that ignored either side would keep two tools.
        let filtered = FilteredToolset::new(
            toolset,
            filters::and(
                filters::name_prefix("tool".to_string()),
                filters::not(filters::name_contains("_b".to_string())),
            ),
        );

        let ctx = RunContext::minimal("test");
        let tools = filtered.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("tool_a"));
    }

    #[tokio::test]
    async fn test_filter_predicates_or() {
        let toolset = FunctionToolset::new()
            .tool(ToolA)
            .tool(ToolB)
            .tool(DangerousTool);

        // Each operand admits a different tool; "tool_b" matches neither.
        let filtered = FilteredToolset::new(
            toolset,
            filters::or(
                filters::name_suffix("_a".to_string()),
                filters::name_prefix("dangerous".to_string()),
            ),
        );

        let ctx = RunContext::minimal("test");
        let tools = filtered.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 2);
        assert!(tools.contains_key("tool_a"));
        assert!(tools.contains_key("dangerous_delete"));
        assert!(!tools.contains_key("tool_b"));
    }
}