// shell_mcp/allowlist.rs

//! Allowlist representation and matching.
//!
//! A [`Rule`] is an ordered list of glob patterns. Each pattern matches one
//! command token positionally, with one exception: a trailing `**` matches
//! zero or more remaining tokens. So `cargo build **` accepts `cargo build`
//! itself and every `cargo build ...` invocation, while `cargo build *`
//! accepts exactly one extra argument.
//!
//! Patterns use the [`glob`] crate's syntax: `*` matches any run of characters
//! within a single token (path separators included, since the default match
//! options are used), `?` matches one character, and `[abc]` matches a
//! character class.
11
12use std::fmt;
13
14use glob::Pattern;
15
16/// A single allowlist entry, parsed once and matched many times.
17#[derive(Clone)]
18pub struct Rule {
19    /// The original textual form, kept for diagnostics and `shell_describe`.
20    raw: String,
21    /// Compiled glob patterns, one per token.
22    patterns: Vec<Pattern>,
23    /// True if the final pattern is the literal `**` rest-matcher.
24    rest_match: bool,
25    /// Where the rule came from (TOML path or `<defaults>`).
26    source: String,
27}
28
29impl Rule {
30    /// Parse a single allowlist entry like `git log --oneline *`.
31    pub fn parse(raw: impl Into<String>, source: impl Into<String>) -> Result<Self, RuleError> {
32        let raw = raw.into();
33        let source = source.into();
34        let tokens = shlex::split(raw.trim()).ok_or_else(|| RuleError::Parse {
35            raw: raw.clone(),
36            reason: "unbalanced quotes".to_string(),
37        })?;
38        if tokens.is_empty() {
39            return Err(RuleError::Parse {
40                raw: raw.clone(),
41                reason: "empty rule".to_string(),
42            });
43        }
44        let rest_match = tokens.last().map(|t| t == "**").unwrap_or(false);
45        let pattern_tokens = if rest_match {
46            &tokens[..tokens.len() - 1]
47        } else {
48            &tokens[..]
49        };
50        let patterns = pattern_tokens
51            .iter()
52            .map(|t| {
53                Pattern::new(t).map_err(|e| RuleError::Parse {
54                    raw: raw.clone(),
55                    reason: format!("invalid glob `{t}`: {e}"),
56                })
57            })
58            .collect::<Result<Vec<_>, _>>()?;
59        Ok(Self {
60            raw,
61            patterns,
62            rest_match,
63            source,
64        })
65    }
66
67    /// True if every token in `cmd` is matched by the corresponding pattern.
68    pub fn matches(&self, cmd: &[String]) -> bool {
69        if self.rest_match {
70            if cmd.len() < self.patterns.len() {
71                return false;
72            }
73        } else if cmd.len() != self.patterns.len() {
74            return false;
75        }
76        self.patterns
77            .iter()
78            .zip(cmd.iter())
79            .all(|(pat, tok)| pat.matches(tok))
80    }
81
82    pub fn raw(&self) -> &str {
83        &self.raw
84    }
85
86    pub fn source(&self) -> &str {
87        &self.source
88    }
89}
90
91impl fmt::Debug for Rule {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        f.debug_struct("Rule")
94            .field("raw", &self.raw)
95            .field("source", &self.source)
96            .finish()
97    }
98}
99
100#[derive(Debug, thiserror::Error)]
101pub enum RuleError {
102    #[error("could not parse rule `{raw}`: {reason}")]
103    Parse { raw: String, reason: String },
104}
105
106/// An ordered collection of rules. Rules are matched in order; the first
107/// match wins (used to surface *which* rule allowed a command).
108#[derive(Clone, Debug, Default)]
109pub struct Allowlist {
110    rules: Vec<Rule>,
111}
112
113impl Allowlist {
114    pub fn new() -> Self {
115        Self::default()
116    }
117
118    pub fn from_rules(rules: Vec<Rule>) -> Self {
119        Self { rules }
120    }
121
122    pub fn push(&mut self, rule: Rule) {
123        self.rules.push(rule);
124    }
125
126    pub fn extend(&mut self, other: Allowlist) {
127        self.rules.extend(other.rules);
128    }
129
130    pub fn rules(&self) -> &[Rule] {
131        &self.rules
132    }
133
134    /// Return the first rule that matches, or `None` if the command is denied.
135    pub fn find_match(&self, cmd: &[String]) -> Option<&Rule> {
136        self.rules.iter().find(|r| r.matches(cmd))
137    }
138}
139
140/// Default *read-only* allowlist for the current platform.
141///
142/// "Read-only" is a descriptive label, not a guarantee — `git log` reads from
143/// `.git`, but technically `cargo metadata` may write to a target dir cache.
144/// The intent is to ship a useful exploration toolkit out of the box.
145pub fn platform_defaults() -> Allowlist {
146    let raw_rules: &[&str] = if cfg!(windows) {
147        &[
148            "dir",
149            "dir **",
150            "type *",
151            "type **",
152            "findstr **",
153            "where *",
154            "where **",
155            "tree",
156            "tree /F",
157            "tree /F **",
158            "git status",
159            "git status **",
160            "git log",
161            "git log **",
162            "git diff",
163            "git diff **",
164            "git show",
165            "git show **",
166            "git branch",
167            "git branch **",
168            "git remote -v",
169            "cargo metadata",
170            "cargo metadata **",
171            "cargo tree",
172            "cargo tree **",
173            "cargo --version",
174            "rustc --version",
175            "whoami",
176        ]
177    } else {
178        &[
179            "ls",
180            "ls **",
181            "cat *",
182            "cat **",
183            "head *",
184            "head **",
185            "tail *",
186            "tail **",
187            "wc *",
188            "wc **",
189            "grep **",
190            "rg **",
191            "find **",
192            "tree",
193            "tree **",
194            "file *",
195            "file **",
196            "stat *",
197            "stat **",
198            "pwd",
199            "which *",
200            "which **",
201            "echo",
202            "echo **",
203            "env",
204            "git status",
205            "git status **",
206            "git log",
207            "git log **",
208            "git diff",
209            "git diff **",
210            "git show",
211            "git show **",
212            "git branch",
213            "git branch **",
214            "git remote -v",
215            "cargo metadata",
216            "cargo metadata **",
217            "cargo tree",
218            "cargo tree **",
219            "cargo --version",
220            "rustc --version",
221        ]
222    };
223    let rules = raw_rules
224        .iter()
225        .map(|r| Rule::parse(*r, "<defaults>").expect("default rules must parse"))
226        .collect();
227    Allowlist::from_rules(rules)
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenize a command line with `shlex`, as `Rule::parse` does for rules.
    fn tokens(s: &str) -> Vec<String> {
        shlex::split(s).unwrap()
    }

    #[test]
    fn exact_match() {
        let r = Rule::parse("git status", "test").unwrap();
        assert!(r.matches(&tokens("git status")));
        assert!(!r.matches(&tokens("git status --short")));
    }

    #[test]
    fn single_glob_matches_one_token() {
        let r = Rule::parse("cargo build *", "test").unwrap();
        assert!(r.matches(&tokens("cargo build --release")));
        assert!(!r.matches(&tokens("cargo build")));
        assert!(!r.matches(&tokens("cargo build --release --offline")));
    }

    #[test]
    fn double_star_matches_rest() {
        let r = Rule::parse("git log **", "test").unwrap();
        assert!(r.matches(&tokens("git log")));
        assert!(r.matches(&tokens("git log --oneline -n 5")));
    }

    #[test]
    fn double_star_still_requires_the_prefix() {
        let r = Rule::parse("git log **", "test").unwrap();
        // Too short to cover the patterned prefix.
        assert!(!r.matches(&tokens("git")));
        // Prefix present but with the wrong token in position 1.
        assert!(!r.matches(&tokens("git status")));
        assert!(!r.matches(&tokens("git status log")));
    }

    #[test]
    fn glob_syntax_applies_within_a_token() {
        let r = Rule::parse("ls -[al] src/*.rs", "test").unwrap();
        assert!(r.matches(&tokens("ls -a src/main.rs")));
        assert!(r.matches(&tokens("ls -l src/lib.rs")));
        assert!(!r.matches(&tokens("ls -x src/main.rs")));
        assert!(!r.matches(&tokens("ls -a src/main.py")));
    }

    #[test]
    fn parse_rejects_empty_rules() {
        for raw in ["", "   "] {
            let err = Rule::parse(raw, "test").unwrap_err();
            assert!(err.to_string().contains("empty rule"), "{err}");
        }
    }

    #[test]
    fn parse_rejects_unbalanced_quotes() {
        let err = Rule::parse("grep \"oops", "test").unwrap_err();
        assert!(err.to_string().contains("unbalanced quotes"), "{err}");
    }

    #[test]
    fn parse_rejects_invalid_globs() {
        let err = Rule::parse("ls [abc", "test").unwrap_err();
        assert!(err.to_string().contains("invalid glob `[abc`"), "{err}");
    }

    #[test]
    fn accessors_return_original_text() {
        let r = Rule::parse("git log **", "config.toml").unwrap();
        assert_eq!(r.raw(), "git log **");
        assert_eq!(r.source(), "config.toml");
    }

    #[test]
    fn first_matching_rule_wins() {
        let mut al = Allowlist::new();
        al.push(Rule::parse("ls", "a").unwrap());
        al.push(Rule::parse("ls **", "b").unwrap());
        assert_eq!(al.find_match(&tokens("ls")).map(Rule::raw), Some("ls"));
        assert_eq!(al.find_match(&tokens("ls -la")).map(Rule::raw), Some("ls **"));
        assert!(al.find_match(&tokens("cat x")).is_none());
    }

    #[test]
    fn extend_appends_rules_in_order() {
        let mut al = Allowlist::from_rules(vec![Rule::parse("pwd", "a").unwrap()]);
        let mut more = Allowlist::new();
        more.push(Rule::parse("whoami", "b").unwrap());
        al.extend(more);
        let raws: Vec<&str> = al.rules().iter().map(Rule::raw).collect();
        assert_eq!(raws, ["pwd", "whoami"]);
    }

    #[test]
    fn defaults_allow_pwd() {
        let al = platform_defaults();
        // Previously this test asserted nothing on Windows; check that branch too.
        if cfg!(windows) {
            assert!(al.find_match(&tokens("dir")).is_some());
            assert!(al.find_match(&tokens("whoami")).is_some());
        } else {
            assert!(al.find_match(&tokens("pwd")).is_some());
        }
        assert!(al.find_match(&tokens("git status")).is_some());
        assert!(al.find_match(&tokens("git log --oneline")).is_some());
    }

    #[test]
    fn defaults_are_tagged_with_their_source() {
        let al = platform_defaults();
        assert!(!al.rules().is_empty());
        assert!(al.rules().iter().all(|r| r.source() == "<defaults>"));
    }

    #[test]
    fn unknown_commands_are_denied_by_default() {
        let al = platform_defaults();
        assert!(al.find_match(&tokens("dangerous-thing --yolo")).is_none());
    }
}