1use std::fmt;
13
14use glob::Pattern;
15
16#[derive(Clone)]
18pub struct Rule {
19 raw: String,
21 patterns: Vec<Pattern>,
23 rest_match: bool,
25 source: String,
27}
28
29impl Rule {
30 pub fn parse(raw: impl Into<String>, source: impl Into<String>) -> Result<Self, RuleError> {
32 let raw = raw.into();
33 let source = source.into();
34 let tokens = shlex::split(raw.trim()).ok_or_else(|| RuleError::Parse {
35 raw: raw.clone(),
36 reason: "unbalanced quotes".to_string(),
37 })?;
38 if tokens.is_empty() {
39 return Err(RuleError::Parse {
40 raw: raw.clone(),
41 reason: "empty rule".to_string(),
42 });
43 }
44 let rest_match = tokens.last().map(|t| t == "**").unwrap_or(false);
45 let pattern_tokens = if rest_match {
46 &tokens[..tokens.len() - 1]
47 } else {
48 &tokens[..]
49 };
50 let patterns = pattern_tokens
51 .iter()
52 .map(|t| {
53 Pattern::new(t).map_err(|e| RuleError::Parse {
54 raw: raw.clone(),
55 reason: format!("invalid glob `{t}`: {e}"),
56 })
57 })
58 .collect::<Result<Vec<_>, _>>()?;
59 Ok(Self {
60 raw,
61 patterns,
62 rest_match,
63 source,
64 })
65 }
66
67 pub fn matches(&self, cmd: &[String]) -> bool {
69 if self.rest_match {
70 if cmd.len() < self.patterns.len() {
71 return false;
72 }
73 } else if cmd.len() != self.patterns.len() {
74 return false;
75 }
76 self.patterns
77 .iter()
78 .zip(cmd.iter())
79 .all(|(pat, tok)| pat.matches(tok))
80 }
81
82 pub fn raw(&self) -> &str {
83 &self.raw
84 }
85
86 pub fn source(&self) -> &str {
87 &self.source
88 }
89}
90
91impl fmt::Debug for Rule {
92 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93 f.debug_struct("Rule")
94 .field("raw", &self.raw)
95 .field("source", &self.source)
96 .finish()
97 }
98}
99
100#[derive(Debug, thiserror::Error)]
101pub enum RuleError {
102 #[error("could not parse rule `{raw}`: {reason}")]
103 Parse { raw: String, reason: String },
104}
105
106#[derive(Clone, Debug, Default)]
109pub struct Allowlist {
110 rules: Vec<Rule>,
111}
112
113impl Allowlist {
114 pub fn new() -> Self {
115 Self::default()
116 }
117
118 pub fn from_rules(rules: Vec<Rule>) -> Self {
119 Self { rules }
120 }
121
122 pub fn push(&mut self, rule: Rule) {
123 self.rules.push(rule);
124 }
125
126 pub fn extend(&mut self, other: Allowlist) {
127 self.rules.extend(other.rules);
128 }
129
130 pub fn rules(&self) -> &[Rule] {
131 &self.rules
132 }
133
134 pub fn find_match(&self, cmd: &[String]) -> Option<&Rule> {
136 self.rules.iter().find(|r| r.matches(cmd))
137 }
138}
139
140pub fn platform_defaults() -> Allowlist {
146 let raw_rules: &[&str] = if cfg!(windows) {
147 &[
148 "dir",
149 "dir **",
150 "type *",
151 "type **",
152 "findstr **",
153 "where *",
154 "where **",
155 "tree",
156 "tree /F",
157 "tree /F **",
158 "git status",
159 "git status **",
160 "git log",
161 "git log **",
162 "git diff",
163 "git diff **",
164 "git show",
165 "git show **",
166 "git branch",
167 "git branch **",
168 "git remote -v",
169 "cargo metadata",
170 "cargo metadata **",
171 "cargo tree",
172 "cargo tree **",
173 "cargo --version",
174 "rustc --version",
175 "whoami",
176 ]
177 } else {
178 &[
179 "ls",
180 "ls **",
181 "cat *",
182 "cat **",
183 "head *",
184 "head **",
185 "tail *",
186 "tail **",
187 "wc *",
188 "wc **",
189 "grep **",
190 "rg **",
191 "find **",
192 "tree",
193 "tree **",
194 "file *",
195 "file **",
196 "stat *",
197 "stat **",
198 "pwd",
199 "which *",
200 "which **",
201 "echo",
202 "echo **",
203 "env",
204 "git status",
205 "git status **",
206 "git log",
207 "git log **",
208 "git diff",
209 "git diff **",
210 "git show",
211 "git show **",
212 "git branch",
213 "git branch **",
214 "git remote -v",
215 "cargo metadata",
216 "cargo metadata **",
217 "cargo tree",
218 "cargo tree **",
219 "cargo --version",
220 "rustc --version",
221 ]
222 };
223 let rules = raw_rules
224 .iter()
225 .map(|r| Rule::parse(*r, "<defaults>").expect("default rules must parse"))
226 .collect();
227 Allowlist::from_rules(rules)
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenises a command line with the same shell-style splitting that rules use.
    fn tokens(s: &str) -> Vec<String> {
        shlex::split(s).unwrap()
    }

    #[test]
    fn exact_match() {
        let r = Rule::parse("git status", "test").unwrap();
        assert!(r.matches(&tokens("git status")));
        assert!(!r.matches(&tokens("git status --short")));
        assert!(!r.matches(&tokens("git")));
    }

    #[test]
    fn single_glob_matches_one_token() {
        let r = Rule::parse("cargo build *", "test").unwrap();
        assert!(r.matches(&tokens("cargo build --release")));
        assert!(!r.matches(&tokens("cargo build")));
        assert!(!r.matches(&tokens("cargo build --release --offline")));
    }

    #[test]
    fn double_star_matches_rest() {
        let r = Rule::parse("git log **", "test").unwrap();
        assert!(r.matches(&tokens("git log")));
        assert!(r.matches(&tokens("git log --oneline -n 5")));
        // The fixed prefix in front of `**` must still match token-for-token.
        assert!(!r.matches(&tokens("git")));
        assert!(!r.matches(&tokens("git status --short")));
    }

    #[test]
    fn parse_keeps_raw_text_and_source() {
        let r = Rule::parse("cargo build *", "config.toml").unwrap();
        assert_eq!(r.raw(), "cargo build *");
        assert_eq!(r.source(), "config.toml");
    }

    #[test]
    fn parse_rejects_malformed_rules() {
        for bad in ["", "   ", "echo 'unterminated", "ls ["] {
            assert!(
                Rule::parse(bad, "test").is_err(),
                "rule `{bad}` should be rejected"
            );
        }
    }

    #[test]
    fn parse_errors_explain_the_reason() {
        let empty = Rule::parse("   ", "test").unwrap_err();
        assert!(matches!(
            &empty,
            RuleError::Parse { reason, .. } if reason == "empty rule"
        ));
        let unbalanced = Rule::parse("echo 'oops", "test").unwrap_err();
        assert!(matches!(
            &unbalanced,
            RuleError::Parse { reason, .. } if reason == "unbalanced quotes"
        ));
    }

    #[test]
    fn allowlist_returns_first_matching_rule() {
        let mut al = Allowlist::new();
        al.push(Rule::parse("git log **", "first").unwrap());
        al.push(Rule::parse("git log *", "second").unwrap());
        let hit = al.find_match(&tokens("git log --oneline")).unwrap();
        assert_eq!(hit.source(), "first");
        assert_eq!(hit.raw(), "git log **");
    }

    #[test]
    fn extend_appends_rules_in_order() {
        let mut al = Allowlist::from_rules(vec![Rule::parse("pwd", "a").unwrap()]);
        al.extend(Allowlist::from_rules(vec![Rule::parse("ls **", "b").unwrap()]));
        let sources: Vec<&str> = al.rules().iter().map(|r| r.source()).collect();
        assert_eq!(sources, ["a", "b"]);
    }

    #[test]
    fn empty_allowlist_denies_everything() {
        let al = Allowlist::new();
        assert!(al.find_match(&tokens("pwd")).is_none());
    }

    #[test]
    fn defaults_allow_basic_queries() {
        let al = platform_defaults();
        // Exercise the platform-specific table on every target, not just Unix.
        if cfg!(windows) {
            assert!(al.find_match(&tokens("dir")).is_some());
        } else {
            assert!(al.find_match(&tokens("pwd")).is_some());
        }
        // The git rules are shared by both platform tables.
        assert!(al.find_match(&tokens("git status")).is_some());
        assert!(al.find_match(&tokens("git log --oneline")).is_some());
    }

    #[test]
    fn unknown_commands_are_denied_by_default() {
        let al = platform_defaults();
        assert!(al.find_match(&tokens("dangerous-thing --yolo")).is_none());
        assert!(al.find_match(&[]).is_none());
    }
}