swink_agent/
tool_filter.rs1use std::sync::Arc;
22
23use regex::Regex;
24
25use crate::tool::AgentTool;
26
27#[derive(Debug, Clone)]
36pub enum ToolPattern {
37 Exact(String),
39 Glob(String),
41 Regex(Regex),
43}
44
45impl ToolPattern {
46 #[must_use]
48 pub fn parse(pattern: &str) -> Self {
49 if pattern.starts_with('^') || pattern.ends_with('$') {
50 Regex::new(pattern).map_or_else(|_| Self::Exact(pattern.to_string()), Self::Regex)
51 } else if pattern.contains('*') || pattern.contains('?') {
52 Self::Glob(pattern.to_string())
53 } else {
54 Self::Exact(pattern.to_string())
55 }
56 }
57
58 #[must_use]
60 pub fn matches(&self, name: &str) -> bool {
61 match self {
62 Self::Exact(pat) => name == pat,
63 Self::Glob(pat) => glob_matches(pat, name),
64 Self::Regex(re) => re.is_match(name),
65 }
66 }
67}
68
/// Matches `text` against a glob `pattern`.
///
/// In the pattern, `*` matches any (possibly empty) run of characters and `?`
/// matches exactly one character. Every other character must match literally.
/// Matching works on `char`s (Unicode scalar values), so `?` consumes a whole
/// code point rather than a byte. There is no escape syntax: `*` and `?` in the
/// pattern are always wildcards, even when the text contains those characters.
///
/// Uses the greedy single-backtrack-point algorithm. Only the most recent `*`
/// is ever revisited, so the worst case is O(|pattern| * |text|).
fn glob_matches(pattern: &str, text: &str) -> bool {
    let pat: Vec<char> = pattern.chars().collect();
    let txt: Vec<char> = text.chars().collect();

    let mut p = 0;
    let mut t = 0;
    // (index of the most recent `*` in `pat`, text index that `*` currently extends to)
    let mut backtrack: Option<(usize, usize)> = None;

    while t < txt.len() {
        match pat.get(p).copied() {
            // `*` must be checked before literal equality. Otherwise a `*` in
            // the text would be consumed as a literal, and the wildcard could
            // never absorb further characters.
            Some('*') => {
                backtrack = Some((p, t));
                p += 1;
            }
            Some(c) if c == '?' || c == txt[t] => {
                p += 1;
                t += 1;
            }
            _ => match backtrack {
                Some((star, extent)) => {
                    // Let the last `*` swallow one more text character and retry.
                    p = star + 1;
                    t = extent + 1;
                    backtrack = Some((star, extent + 1));
                }
                None => return false,
            },
        }
    }

    // The text is exhausted; only trailing `*`s (which may match nothing) may remain.
    pat[p..].iter().all(|&c| c == '*')
}
111
112#[derive(Debug, Clone, Default)]
119pub struct ToolFilter {
120 allowed: Vec<ToolPattern>,
122 rejected: Vec<ToolPattern>,
124}
125
126impl ToolFilter {
127 #[must_use]
129 pub fn new() -> Self {
130 Self::default()
131 }
132
133 #[must_use]
135 pub fn with_allowed(mut self, patterns: Vec<ToolPattern>) -> Self {
136 self.allowed = patterns;
137 self
138 }
139
140 #[must_use]
142 pub fn with_rejected(mut self, patterns: Vec<ToolPattern>) -> Self {
143 self.rejected = patterns;
144 self
145 }
146
147 #[must_use]
149 pub fn is_allowed(&self, name: &str) -> bool {
150 if self.rejected.iter().any(|p| p.matches(name)) {
152 return false;
153 }
154 if self.allowed.is_empty() {
156 return true;
157 }
158 self.allowed.iter().any(|p| p.matches(name))
159 }
160
161 #[must_use]
163 pub fn filter_tools(&self, tools: Vec<Arc<dyn AgentTool>>) -> Vec<Arc<dyn AgentTool>> {
164 tools
165 .into_iter()
166 .filter(|t| self.is_allowed(t.name()))
167 .collect()
168 }
169}
170
171const _: () = {
174 const fn assert_send_sync<T: Send + Sync>() {}
175 assert_send_sync::<ToolFilter>();
176 assert_send_sync::<ToolPattern>();
177};
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw`, then checks it matches every name in `accepted` and none
    /// of the names in `refused`.
    fn assert_pattern(raw: &str, accepted: &[&str], refused: &[&str]) {
        let pattern = ToolPattern::parse(raw);
        for &name in accepted {
            assert!(pattern.matches(name), "{raw:?} should match {name:?}");
        }
        for &name in refused {
            assert!(!pattern.matches(name), "{raw:?} should not match {name:?}");
        }
    }

    #[test]
    fn exact_pattern_matches() {
        assert_pattern("bash", &["bash"], &["read_file"]);
    }

    #[test]
    fn glob_pattern_matches() {
        assert_pattern("read_*", &["read_file", "read_secret"], &["write_file"]);
    }

    #[test]
    fn glob_question_mark_matches_single_char() {
        assert_pattern("tool_?", &["tool_a"], &["tool_ab"]);
    }

    #[test]
    fn glob_star_backtracks_without_regex() {
        assert_pattern(
            "read_*_file",
            &["read_secret_file", "read_very_secret_file"],
            &["read_secret_dir"],
        );
    }

    #[test]
    fn glob_handles_unicode_chars() {
        assert_pattern("t?ol_*", &["t🦀ol_alpha"], &["tool"]);
    }

    #[test]
    fn regex_pattern_matches() {
        assert_pattern("^file_.*$", &["file_read", "file_write"], &["bash"]);
    }

    #[test]
    fn rejected_takes_precedence() {
        let filter = ToolFilter::new()
            .with_rejected(vec![ToolPattern::parse("read_secret")])
            .with_allowed(vec![ToolPattern::parse("read_*")]);

        assert!(filter.is_allowed("read_file"));
        assert!(!filter.is_allowed("read_secret"));
    }

    #[test]
    fn empty_filter_allows_all() {
        let filter = ToolFilter::new();
        for name in ["anything", "bash"] {
            assert!(filter.is_allowed(name));
        }
    }

    #[test]
    fn allowed_only_restricts_to_matching() {
        let filter = ToolFilter::new().with_allowed(vec![ToolPattern::parse("bash")]);
        assert!(filter.is_allowed("bash"));
        assert!(!filter.is_allowed("read_file"));
    }

    #[test]
    fn rejected_only_excludes_matching() {
        let filter = ToolFilter::new().with_rejected(vec![ToolPattern::parse("bash")]);
        assert!(!filter.is_allowed("bash"));
        assert!(filter.is_allowed("read_file"));
    }

    #[test]
    fn invalid_regex_falls_back_to_exact() {
        assert_pattern("^[invalid", &["^[invalid"], &[]);
    }

    #[test]
    fn parse_detects_pattern_type() {
        assert!(matches!(ToolPattern::parse("exact"), ToolPattern::Exact(_)));
        assert!(matches!(ToolPattern::parse("glob_*"), ToolPattern::Glob(_)));
        assert!(matches!(ToolPattern::parse("^regex$"), ToolPattern::Regex(_)));
    }
}