1use async_trait::async_trait;
7use serde_json::Value as JsonValue;
8use serdes_ai_tools::{RunContext, ToolDefinition, ToolError, ToolReturn};
9use std::collections::HashMap;
10use std::marker::PhantomData;
11
12use crate::{AbstractToolset, ToolsetTool};
13
/// A toolset wrapper that hides the tools rejected by a predicate.
///
/// The `filter` predicate is consulted twice: when listing tools (`get_tools`)
/// and again when a tool is invoked (`call_tool`). A filtered-out tool can
/// therefore neither be discovered nor executed through this wrapper.
pub struct FilteredToolset<T, F, Deps = ()> {
    /// The wrapped toolset that provides and executes the tools.
    inner: T,
    /// Predicate deciding, for a given run context, whether a tool definition is exposed.
    filter: F,
    /// Optional id that takes precedence over the inner toolset's id.
    id: Option<String>,
    /// Mentions `Deps` without storing a value of it.
    ///
    /// Using `fn() -> Deps` keeps the wrapper covariant in `Deps`. It also keeps
    /// the wrapper's `Send`/`Sync` status independent of `Deps` itself.
    _phantom: PhantomData<fn() -> Deps>,
}
38
39impl<T, F, Deps> FilteredToolset<T, F, Deps>
40where
41 T: AbstractToolset<Deps>,
42 F: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
43{
44 pub fn new(inner: T, filter: F) -> Self {
46 Self {
47 inner,
48 filter,
49 id: None,
50 _phantom: PhantomData,
51 }
52 }
53
54 #[must_use]
56 pub fn with_id(mut self, id: impl Into<String>) -> Self {
57 self.id = Some(id.into());
58 self
59 }
60
61 #[must_use]
63 pub fn inner(&self) -> &T {
64 &self.inner
65 }
66}
67
68#[async_trait]
69impl<T, F, Deps> AbstractToolset<Deps> for FilteredToolset<T, F, Deps>
70where
71 T: AbstractToolset<Deps>,
72 F: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
73 Deps: Send + Sync,
74{
75 fn id(&self) -> Option<&str> {
76 self.id.as_deref().or_else(|| self.inner.id())
77 }
78
79 fn type_name(&self) -> &'static str {
80 "FilteredToolset"
81 }
82
83 fn label(&self) -> String {
84 format!("FilteredToolset({})", self.inner.label())
85 }
86
87 async fn get_tools(
88 &self,
89 ctx: &RunContext<Deps>,
90 ) -> Result<HashMap<String, ToolsetTool>, ToolError> {
91 let all_tools = self.inner.get_tools(ctx).await?;
92
93 Ok(all_tools
94 .into_iter()
95 .filter(|(_, tool)| (self.filter)(ctx, &tool.tool_def))
96 .collect())
97 }
98
99 async fn call_tool(
100 &self,
101 name: &str,
102 args: JsonValue,
103 ctx: &RunContext<Deps>,
104 tool: &ToolsetTool,
105 ) -> Result<ToolReturn, ToolError> {
106 if !(self.filter)(ctx, &tool.tool_def) {
108 return Err(ToolError::not_found(format!(
109 "Tool '{}' is not available (filtered out)",
110 name
111 )));
112 }
113
114 self.inner.call_tool(name, args, ctx, tool).await
115 }
116
117 async fn enter(&self) -> Result<(), ToolError> {
118 self.inner.enter().await
119 }
120
121 async fn exit(&self) -> Result<(), ToolError> {
122 self.inner.exit().await
123 }
124}
125
126impl<T: std::fmt::Debug, F, Deps> std::fmt::Debug for FilteredToolset<T, F, Deps> {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 f.debug_struct("FilteredToolset")
129 .field("inner", &self.inner)
130 .field("id", &self.id)
131 .finish()
132 }
133}
134
135pub mod filters {
137 use serdes_ai_tools::{RunContext, ToolDefinition};
138
139 pub fn allow_names<Deps>(
141 names: Vec<String>,
142 ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync {
143 move |_, def| names.iter().any(|n| n == &def.name)
144 }
145
146 pub fn deny_names<Deps>(
148 names: Vec<String>,
149 ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync {
150 move |_, def| !names.iter().any(|n| n == &def.name)
151 }
152
153 pub fn name_prefix<Deps>(
155 prefix: String,
156 ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync {
157 move |_, def| def.name.starts_with(&prefix)
158 }
159
160 pub fn name_suffix<Deps>(
162 suffix: String,
163 ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync {
164 move |_, def| def.name.ends_with(&suffix)
165 }
166
167 pub fn name_contains<Deps>(
169 substring: String,
170 ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync {
171 move |_, def| def.name.contains(&substring)
172 }
173
174 pub fn and<F1, F2, Deps>(
176 f1: F1,
177 f2: F2,
178 ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync
179 where
180 F1: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
181 F2: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
182 {
183 move |ctx, def| f1(ctx, def) && f2(ctx, def)
184 }
185
186 pub fn or<F1, F2, Deps>(
188 f1: F1,
189 f2: F2,
190 ) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync
191 where
192 F1: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
193 F2: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
194 {
195 move |ctx, def| f1(ctx, def) || f2(ctx, def)
196 }
197
198 pub fn not<F, Deps>(f: F) -> impl Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync
200 where
201 F: Fn(&RunContext<Deps>, &ToolDefinition) -> bool + Send + Sync,
202 {
203 move |ctx, def| !f(ctx, def)
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FunctionToolset;
    use async_trait::async_trait;
    use serdes_ai_tools::Tool;

    /// Harmless tool named `tool_a`; always returns the text "A".
    struct ToolA;

    #[async_trait]
    impl Tool<()> for ToolA {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_a", "Tool A")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("A"))
        }
    }

    /// Harmless tool named `tool_b`; always returns the text "B".
    struct ToolB;

    #[async_trait]
    impl Tool<()> for ToolB {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("tool_b", "Tool B")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("B"))
        }
    }

    /// Tool named `dangerous_delete`; the tests expect filters to hide it.
    struct DangerousTool;

    #[async_trait]
    impl Tool<()> for DangerousTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition::new("dangerous_delete", "Dangerous delete operation")
        }

        async fn call(
            &self,
            _ctx: &RunContext<()>,
            _args: JsonValue,
        ) -> Result<ToolReturn, ToolError> {
            Ok(ToolReturn::text("Deleted!"))
        }
    }

    #[tokio::test]
    async fn test_filtered_toolset() {
        let toolset = FunctionToolset::new()
            .tool(ToolA)
            .tool(ToolB)
            .tool(DangerousTool);

        let filtered = FilteredToolset::new(toolset, |_, def| !def.name.contains("dangerous"));

        let ctx = RunContext::minimal("test");
        let tools = filtered.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 2);
        assert!(tools.contains_key("tool_a"));
        assert!(tools.contains_key("tool_b"));
        assert!(!tools.contains_key("dangerous_delete"));
    }

    #[tokio::test]
    async fn test_filtered_toolset_call_blocked() {
        let toolset = FunctionToolset::new().tool(ToolA).tool(DangerousTool);

        let filtered = FilteredToolset::new(toolset, |_, def| !def.name.contains("dangerous"));

        let ctx = RunContext::minimal("test");

        // A tool handle obtained outside `get_tools` must still be rejected.
        let fake_tool = ToolsetTool::new(ToolDefinition::new("dangerous_delete", "Dangerous"));

        let result = filtered
            .call_tool("dangerous_delete", serde_json::json!({}), &ctx, &fake_tool)
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_filtered_toolset_call_allowed() {
        let toolset = FunctionToolset::new().tool(ToolA).tool(DangerousTool);

        let filtered = FilteredToolset::new(toolset, |_, def| !def.name.contains("dangerous"));

        let ctx = RunContext::minimal("test");
        let tools = filtered.get_tools(&ctx).await.unwrap();
        let tool_a = tools
            .get("tool_a")
            .expect("tool_a should survive the filter");

        // An accepted tool must be forwarded to the inner toolset.
        let result = filtered
            .call_tool("tool_a", serde_json::json!({}), &ctx, tool_a)
            .await;

        assert!(result.is_ok());
    }

    #[test]
    fn test_filtered_toolset_id_and_label() {
        let toolset = FunctionToolset::new().tool(ToolA);
        let filtered = FilteredToolset::new(toolset, |_, _| true);

        // Without an explicit id, the inner toolset's id is used.
        assert_eq!(filtered.id(), filtered.inner().id());
        assert_eq!(filtered.type_name(), "FilteredToolset");

        let label = filtered.label();
        assert!(label.starts_with("FilteredToolset("));
        assert!(label.ends_with(')'));

        let filtered = filtered.with_id("restricted");
        assert_eq!(filtered.id(), Some("restricted"));
    }

    #[tokio::test]
    async fn test_filter_predicates_allow_names() {
        let toolset = FunctionToolset::new().tool(ToolA).tool(ToolB);

        let filtered =
            FilteredToolset::new(toolset, filters::allow_names(vec!["tool_a".to_string()]));

        let ctx = RunContext::minimal("test");
        let tools = filtered.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("tool_a"));
    }

    #[tokio::test]
    async fn test_filter_predicates_deny_names() {
        let toolset = FunctionToolset::new().tool(ToolA).tool(ToolB);

        let filtered =
            FilteredToolset::new(toolset, filters::deny_names(vec!["tool_b".to_string()]));

        let ctx = RunContext::minimal("test");
        let tools = filtered.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 1);
        assert!(tools.contains_key("tool_a"));
    }

    #[tokio::test]
    async fn test_filter_predicates_combined() {
        let toolset = FunctionToolset::new()
            .tool(ToolA)
            .tool(ToolB)
            .tool(DangerousTool);

        let filtered = FilteredToolset::new(
            toolset,
            filters::and(
                filters::name_prefix("tool".to_string()),
                filters::not(filters::name_contains("dangerous".to_string())),
            ),
        );

        let ctx = RunContext::minimal("test");
        let tools = filtered.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 2);
        assert!(tools.contains_key("tool_a"));
        assert!(tools.contains_key("tool_b"));
        assert!(!tools.contains_key("dangerous_delete"));
    }

    #[tokio::test]
    async fn test_filter_predicates_or() {
        let toolset = FunctionToolset::new()
            .tool(ToolA)
            .tool(ToolB)
            .tool(DangerousTool);

        let filtered = FilteredToolset::new(
            toolset,
            filters::or(
                filters::name_suffix("_a".to_string()),
                filters::name_prefix("dangerous".to_string()),
            ),
        );

        let ctx = RunContext::minimal("test");
        let tools = filtered.get_tools(&ctx).await.unwrap();

        assert_eq!(tools.len(), 2);
        assert!(tools.contains_key("tool_a"));
        assert!(tools.contains_key("dangerous_delete"));
        assert!(!tools.contains_key("tool_b"));
    }
}
351}