Skip to main content

swink_agent/
tool_filter.rs

1//! Pattern-based tool filtering at registration time.
2//!
3//! [`ToolFilter`] uses exact, glob, and regex patterns to restrict which tools
4//! are available to the agent. Patterns are applied at registration time so that
5//! filtered tools never appear in the LLM prompt.
6//!
7//! # Example
8//!
9//! ```
10//! use swink_agent::{ToolFilter, ToolPattern};
11//!
12//! let filter = ToolFilter::new()
13//!     .with_allowed(vec![ToolPattern::parse("read_*")])
14//!     .with_rejected(vec![ToolPattern::parse("read_secret")]);
15//!
16//! assert!(filter.is_allowed("read_file"));
17//! assert!(!filter.is_allowed("read_secret"));
18//! assert!(!filter.is_allowed("bash"));
19//! ```
20
21use std::sync::Arc;
22
23use regex::Regex;
24
25use crate::tool::AgentTool;
26
27// ─── ToolPattern ────────────────────────────────────────────────────────────
28
29/// A pattern for matching tool names.
30///
31/// Auto-detected by [`parse()`](ToolPattern::parse):
32/// - Strings starting with `^` or ending with `$` → [`Regex`](ToolPattern::Regex)
33/// - Strings containing `*` or `?` → [`Glob`](ToolPattern::Glob)
34/// - Everything else → [`Exact`](ToolPattern::Exact)
35#[derive(Debug, Clone)]
36pub enum ToolPattern {
37    /// Match the tool name exactly.
38    Exact(String),
39    /// Match using glob syntax (`*` = any chars, `?` = single char).
40    Glob(String),
41    /// Match using a regular expression.
42    Regex(Regex),
43}
44
45impl ToolPattern {
46    /// Parse a pattern string, auto-detecting the pattern type.
47    #[must_use]
48    pub fn parse(pattern: &str) -> Self {
49        if pattern.starts_with('^') || pattern.ends_with('$') {
50            Regex::new(pattern).map_or_else(|_| Self::Exact(pattern.to_string()), Self::Regex)
51        } else if pattern.contains('*') || pattern.contains('?') {
52            Self::Glob(pattern.to_string())
53        } else {
54            Self::Exact(pattern.to_string())
55        }
56    }
57
58    /// Test whether this pattern matches the given tool name.
59    #[must_use]
60    pub fn matches(&self, name: &str) -> bool {
61        match self {
62            Self::Exact(pat) => name == pat,
63            Self::Glob(pat) => glob_matches(pat, name),
64            Self::Regex(re) => re.is_match(name),
65        }
66    }
67}
68
/// Simple glob matching: `*` matches any sequence (including the empty one),
/// `?` matches exactly one char, and every other char matches itself.
///
/// Matching is done per `char`, so multi-byte characters count as a single
/// position for `?`. The whole of `text` must be consumed by `pattern`.
///
/// A `*` in the pattern is always a wildcard, even when the text char at the
/// current position is itself a literal `*`. Testing for the wildcard before
/// the literal comparison is what guarantees this; comparing first would
/// consume the pattern's `*` as a plain character and lose the backtrack
/// point (e.g. `"*a"` would fail to match `"*ba"`).
fn glob_matches(pattern: &str, text: &str) -> bool {
    let pattern_chars: Vec<char> = pattern.chars().collect();
    let text_chars: Vec<char> = text.chars().collect();
    let mut pattern_idx = 0;
    let mut text_idx = 0;
    // Most recent `*` seen: (its index in the pattern, the text index up to
    // which it is currently assumed to have consumed characters).
    let mut backtrack: Option<(usize, usize)> = None;

    while text_idx < text_chars.len() {
        match pattern_chars.get(pattern_idx).copied() {
            // Wildcard: remember it and first try matching zero characters.
            Some('*') => {
                backtrack = Some((pattern_idx, text_idx));
                pattern_idx += 1;
            }
            // Single-char wildcard or literal match: advance both sides.
            Some(pc) if pc == '?' || pc == text_chars[text_idx] => {
                pattern_idx += 1;
                text_idx += 1;
            }
            // Mismatch or pattern exhausted: let the last `*` absorb one more
            // text char and retry from just after it, or fail if there is none.
            _ => match backtrack {
                Some((star, consumed)) => {
                    let resume = consumed + 1;
                    backtrack = Some((star, resume));
                    pattern_idx = star + 1;
                    text_idx = resume;
                }
                None => return false,
            },
        }
    }

    // Text is exhausted; only trailing `*`s may remain in the pattern.
    pattern_chars[pattern_idx..].iter().all(|&c| c == '*')
}
111
112// ─── ToolFilter ─────────────────────────────────────────────────────────────
113
114/// Filters tools at registration time using pattern-based allow/reject lists.
115///
116/// When both `allowed` and `rejected` match a tool name, `rejected` takes
117/// precedence — the tool is excluded.
118#[derive(Debug, Clone, Default)]
119pub struct ToolFilter {
120    /// Patterns that a tool name must match to be included. Empty = allow all.
121    allowed: Vec<ToolPattern>,
122    /// Patterns that exclude a tool name. Takes precedence over `allowed`.
123    rejected: Vec<ToolPattern>,
124}
125
126impl ToolFilter {
127    /// Create a new empty filter (allows all tools).
128    #[must_use]
129    pub fn new() -> Self {
130        Self::default()
131    }
132
133    /// Set the allowed patterns.
134    #[must_use]
135    pub fn with_allowed(mut self, patterns: Vec<ToolPattern>) -> Self {
136        self.allowed = patterns;
137        self
138    }
139
140    /// Set the rejected patterns.
141    #[must_use]
142    pub fn with_rejected(mut self, patterns: Vec<ToolPattern>) -> Self {
143        self.rejected = patterns;
144        self
145    }
146
147    /// Test whether a tool name passes through this filter.
148    #[must_use]
149    pub fn is_allowed(&self, name: &str) -> bool {
150        // Rejected takes precedence.
151        if self.rejected.iter().any(|p| p.matches(name)) {
152            return false;
153        }
154        // If no allowed patterns, everything passes. Otherwise must match at least one.
155        if self.allowed.is_empty() {
156            return true;
157        }
158        self.allowed.iter().any(|p| p.matches(name))
159    }
160
161    /// Filter a list of tools, returning only those that pass the filter.
162    #[must_use]
163    pub fn filter_tools(&self, tools: Vec<Arc<dyn AgentTool>>) -> Vec<Arc<dyn AgentTool>> {
164        tools
165            .into_iter()
166            .filter(|t| self.is_allowed(t.name()))
167            .collect()
168    }
169}
170
171// ─── Compile-time Send + Sync assertions ────────────────────────────────────
172
173const _: () = {
174    const fn assert_send_sync<T: Send + Sync>() {}
175    assert_send_sync::<ToolFilter>();
176    assert_send_sync::<ToolPattern>();
177};
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `raw` and check that it accepts every name in `hits` and no name
    /// in `misses`.
    fn assert_pattern(raw: &str, hits: &[&str], misses: &[&str]) {
        let parsed = ToolPattern::parse(raw);
        for &name in hits {
            assert!(parsed.matches(name), "{:?} should match {:?}", raw, name);
        }
        for &name in misses {
            assert!(!parsed.matches(name), "{:?} should not match {:?}", raw, name);
        }
    }

    /// Check that `filter` lets through every name in `passing` and blocks
    /// every name in `blocked`.
    fn assert_filter(filter: &ToolFilter, passing: &[&str], blocked: &[&str]) {
        for &name in passing {
            assert!(filter.is_allowed(name), "{:?} should be allowed", name);
        }
        for &name in blocked {
            assert!(!filter.is_allowed(name), "{:?} should be filtered out", name);
        }
    }

    #[test]
    fn exact_pattern_matches() {
        assert_pattern("bash", &["bash"], &["read_file"]);
    }

    #[test]
    fn glob_pattern_matches() {
        assert_pattern("read_*", &["read_file", "read_secret"], &["write_file"]);
    }

    #[test]
    fn glob_question_mark_matches_single_char() {
        assert_pattern("tool_?", &["tool_a"], &["tool_ab"]);
    }

    #[test]
    fn glob_star_backtracks_without_regex() {
        assert_pattern(
            "read_*_file",
            &["read_secret_file", "read_very_secret_file"],
            &["read_secret_dir"],
        );
    }

    #[test]
    fn glob_handles_unicode_chars() {
        assert_pattern("t?ol_*", &["t🦀ol_alpha"], &["tool"]);
    }

    #[test]
    fn regex_pattern_matches() {
        assert_pattern("^file_.*$", &["file_read", "file_write"], &["bash"]);
    }

    #[test]
    fn rejected_takes_precedence() {
        let allow = vec![ToolPattern::parse("read_*")];
        let reject = vec![ToolPattern::parse("read_secret")];
        let filter = ToolFilter::new().with_allowed(allow).with_rejected(reject);

        assert_filter(&filter, &["read_file"], &["read_secret"]);
    }

    #[test]
    fn empty_filter_allows_all() {
        assert_filter(&ToolFilter::new(), &["anything", "bash"], &[]);
    }

    #[test]
    fn allowed_only_restricts_to_matching() {
        let allow = vec![ToolPattern::parse("bash")];
        let filter = ToolFilter::new().with_allowed(allow);

        assert_filter(&filter, &["bash"], &["read_file"]);
    }

    #[test]
    fn rejected_only_excludes_matching() {
        let reject = vec![ToolPattern::parse("bash")];
        let filter = ToolFilter::new().with_rejected(reject);

        assert_filter(&filter, &["read_file"], &["bash"]);
    }

    #[test]
    fn invalid_regex_falls_back_to_exact() {
        // The text starts with `^` but is not a valid regex, so it is kept as
        // a literal name and matches itself.
        assert_pattern("^[invalid", &["^[invalid"], &[]);
    }

    #[test]
    fn parse_detects_pattern_type() {
        let exact = ToolPattern::parse("exact");
        let glob = ToolPattern::parse("glob_*");
        let regex = ToolPattern::parse("^regex$");

        assert!(matches!(exact, ToolPattern::Exact(_)));
        assert!(matches!(glob, ToolPattern::Glob(_)));
        assert!(matches!(regex, ToolPattern::Regex(_)));
    }
}