Skip to main content

safe_chains/
policy.rs

1use crate::parse::{Token, WordSet};
2
/// How [`check_flags`] treats a hyphen-prefixed argument that matches no
/// allowed flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagStyle {
    /// Reject the whole invocation on the first unrecognised flag.
    Strict,
    /// Accept unrecognised flags and count each one as a positional argument
    /// (e.g. for `echo`-style commands that print such arguments literally).
    Positional,
}
8
/// Membership queries that [`check_flags`] uses to decide whether an argument
/// is an allowed flag.
///
/// Implemented in this module for [`WordSet`] and for `[String]` /
/// `Vec<String>` lists of flag spellings.
pub trait FlagSet {
    /// Returns `true` when `token` is exactly one of the allowed flags
    /// (for example `--max-count` or `-m`).
    fn contains_flag(&self, token: &str) -> bool;

    /// Returns `true` when the single byte `byte` is allowed as a short flag
    /// inside a short-flag cluster such as `-rn`.
    fn contains_short(&self, byte: u8) -> bool;
}
13
14impl FlagSet for WordSet {
15    fn contains_flag(&self, token: &str) -> bool {
16        self.contains(token)
17    }
18    fn contains_short(&self, byte: u8) -> bool {
19        self.contains_short(byte)
20    }
21}
22
23impl FlagSet for [String] {
24    fn contains_flag(&self, token: &str) -> bool {
25        self.iter().any(|f| f.as_str() == token)
26    }
27    fn contains_short(&self, byte: u8) -> bool {
28        self.iter().any(|f| f.len() == 2 && f.as_bytes()[1] == byte)
29    }
30}
31
32impl FlagSet for Vec<String> {
33    fn contains_flag(&self, token: &str) -> bool {
34        self.as_slice().contains_flag(token)
35    }
36    fn contains_short(&self, byte: u8) -> bool {
37        self.as_slice().contains_short(byte)
38    }
39}
40
41pub struct FlagPolicy {
42    pub standalone: WordSet,
43    pub valued: WordSet,
44    pub bare: bool,
45    pub max_positional: Option<usize>,
46    pub flag_style: FlagStyle,
47}
48
49impl FlagPolicy {
50    pub fn describe(&self) -> String {
51        use crate::docs::wordset_items;
52        let mut lines = Vec::new();
53        let standalone = wordset_items(&self.standalone);
54        if !standalone.is_empty() {
55            lines.push(format!("- Allowed standalone flags: {standalone}"));
56        }
57        let valued = wordset_items(&self.valued);
58        if !valued.is_empty() {
59            lines.push(format!("- Allowed valued flags: {valued}"));
60        }
61        if self.bare {
62            lines.push("- Bare invocation allowed".to_string());
63        }
64        if self.flag_style == FlagStyle::Positional {
65            lines.push("- Hyphen-prefixed positional arguments accepted".to_string());
66        }
67        if lines.is_empty() && !self.bare {
68            return "- Positional arguments only".to_string();
69        }
70        lines.join("\n")
71    }
72
73    pub fn flag_summary(&self) -> String {
74        use crate::docs::wordset_items;
75        let mut parts = Vec::new();
76        let standalone = wordset_items(&self.standalone);
77        if !standalone.is_empty() {
78            parts.push(format!("Flags: {standalone}"));
79        }
80        let valued = wordset_items(&self.valued);
81        if !valued.is_empty() {
82            parts.push(format!("Valued: {valued}"));
83        }
84        if self.flag_style == FlagStyle::Positional {
85            parts.push("Positional args accepted".to_string());
86        }
87        parts.join(". ")
88    }
89}
90
91pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
92    check_flags(
93        tokens,
94        &policy.standalone,
95        &policy.valued,
96        policy.bare,
97        policy.max_positional,
98        policy.flag_style,
99    )
100}
101
102pub fn check_flags<S: FlagSet + ?Sized, V: FlagSet + ?Sized>(
103    tokens: &[Token],
104    standalone: &S,
105    valued: &V,
106    bare: bool,
107    max_positional: Option<usize>,
108    flag_style: FlagStyle,
109) -> bool {
110    if tokens.len() == 1 {
111        return bare;
112    }
113
114    let mut i = 1;
115    let mut positionals: usize = 0;
116    while i < tokens.len() {
117        let t = &tokens[i];
118
119        if *t == "--" {
120            positionals += tokens.len() - i - 1;
121            break;
122        }
123
124        if !t.starts_with('-') {
125            positionals += 1;
126            i += 1;
127            continue;
128        }
129
130        if standalone.contains_flag(t) {
131            i += 1;
132            continue;
133        }
134
135        if valued.contains_flag(t) {
136            i += 2;
137            continue;
138        }
139
140        if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
141            if valued.contains_flag(flag) {
142                i += 1;
143                continue;
144            }
145            if flag_style == FlagStyle::Positional {
146                positionals += 1;
147                i += 1;
148                continue;
149            }
150            return false;
151        }
152
153        if t.starts_with("--") {
154            if flag_style == FlagStyle::Positional {
155                positionals += 1;
156                i += 1;
157                continue;
158            }
159            return false;
160        }
161
162        let bytes = t.as_bytes();
163        let mut j = 1;
164        while j < bytes.len() {
165            let b = bytes[j];
166            let is_last = j == bytes.len() - 1;
167            if standalone.contains_short(b) {
168                j += 1;
169                continue;
170            }
171            if valued.contains_short(b) {
172                if is_last {
173                    i += 1;
174                }
175                break;
176            }
177            if flag_style == FlagStyle::Positional {
178                positionals += 1;
179                break;
180            }
181            return false;
182        }
183        i += 1;
184    }
185    max_positional.is_none_or(|max| positionals <= max)
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// grep-like policy: strict flag checking, bare invocation denied,
    /// unlimited operands.
    static GREP_LIKE: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-H", "-c", "-i", "-l", "-n", "-o", "-r", "-s", "-v", "-w",
        ]),
        valued: WordSet::flags(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        bare: false,
        max_positional: None,
        flag_style: FlagStyle::Strict,
    };

    /// uniq-like policy: bare invocation allowed, at most one operand.
    static UNIQ_LIKE: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--count", "-c", "-d", "-i", "-u"]),
        valued: WordSet::flags(&["--skip-fields", "-f", "-s"]),
        bare: true,
        max_positional: Some(1),
        flag_style: FlagStyle::Strict,
    };

    /// echo-like policy: unrecognised hyphen-prefixed arguments count as operands.
    static ECHO_LIKE: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["-E", "-e", "-n"]),
        valued: WordSet::flags(&[]),
        bare: true,
        max_positional: None,
        flag_style: FlagStyle::Positional,
    };

    /// Turns plain words into a token vector.
    fn argv(words: &[&str]) -> Vec<Token> {
        words.iter().map(|word| Token::from_test(word)).collect()
    }

    /// Runs `check` on plain words.
    fn allowed(words: &[&str], policy: &FlagPolicy) -> bool {
        check(&argv(words), policy)
    }

    // ---- bare invocation -------------------------------------------------

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!allowed(&["grep"], &GREP_LIKE));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let flagless = FlagPolicy {
            standalone: WordSet::flags(&[]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: None,
            flag_style: FlagStyle::Strict,
        };
        assert!(allowed(&["uname"], &flagless));
    }

    // ---- accepted flag forms ---------------------------------------------

    #[test]
    fn standalone_long_flag() {
        assert!(allowed(&["grep", "--recursive", "pattern", "."], &GREP_LIKE));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(allowed(&["grep", "-r", "pattern", "."], &GREP_LIKE));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(allowed(&["grep", "--max-count", "5", "pattern"], &GREP_LIKE));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(allowed(&["grep", "--max-count=5", "pattern"], &GREP_LIKE));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(allowed(&["grep", "-m", "5", "pattern"], &GREP_LIKE));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(allowed(&["grep", "-rn", "pattern", "."], &GREP_LIKE));
    }

    #[test]
    fn combined_short_with_valued_last() {
        assert!(allowed(&["grep", "-rnm", "5", "pattern"], &GREP_LIKE));
    }

    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        assert!(allowed(&["grep", "-rmn", "pattern"], &GREP_LIKE));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(allowed(&["grep", "-A", "3", "-B", "3", "pattern"], &GREP_LIKE));
    }

    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(allowed(&["grep", "--max-count"], &GREP_LIKE));
    }

    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(allowed(&["grep", "-c", "pattern"], &GREP_LIKE));
    }

    #[test]
    fn mixed_flags_and_positional() {
        let words = ["grep", "-rn", "--color", "--max-count", "10", "pattern", "."];
        assert!(allowed(&words, &GREP_LIKE));
    }

    // ---- rejections ------------------------------------------------------

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!allowed(&["grep", "--exec", "cmd"], &GREP_LIKE));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!allowed(&["grep", "-z", "pattern"], &GREP_LIKE));
    }

    #[test]
    fn unknown_combined_short_denied() {
        assert!(!allowed(&["grep", "-rz", "pattern"], &GREP_LIKE));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!allowed(&["grep", "--output=file.txt", "pattern"], &GREP_LIKE));
    }

    // ---- operands --------------------------------------------------------

    #[test]
    fn double_dash_stops_checking() {
        assert!(allowed(&["grep", "--", "--not-a-flag", "file"], &GREP_LIKE));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(allowed(&["grep", "pattern", "file.txt", "other.txt"], &GREP_LIKE));
    }

    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(allowed(&["grep", "pattern", "-"], &GREP_LIKE));
    }

    // ---- max_positional --------------------------------------------------

    #[test]
    fn max_positional_within_limit() {
        assert!(allowed(&["uniq", "input.txt"], &UNIQ_LIKE));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!allowed(&["uniq", "input.txt", "output.txt"], &UNIQ_LIKE));
    }

    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(allowed(&["uniq", "-c", "-f", "3", "input.txt"], &UNIQ_LIKE));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!allowed(&["uniq", "-c", "input.txt", "output.txt"], &UNIQ_LIKE));
    }

    #[test]
    fn max_positional_after_double_dash() {
        assert!(!allowed(&["uniq", "--", "input.txt", "output.txt"], &UNIQ_LIKE));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(allowed(&["uniq"], &UNIQ_LIKE));
    }

    // ---- FlagStyle::Positional -------------------------------------------

    #[test]
    fn positional_style_unknown_long() {
        assert!(allowed(&["echo", "--unknown", "hello"], &ECHO_LIKE));
    }

    #[test]
    fn positional_style_unknown_short() {
        assert!(allowed(&["echo", "-x", "hello"], &ECHO_LIKE));
    }

    #[test]
    fn positional_style_dashes() {
        assert!(allowed(&["echo", "---"], &ECHO_LIKE));
    }

    #[test]
    fn positional_style_known_flags_still_work() {
        assert!(allowed(&["echo", "-n", "hello"], &ECHO_LIKE));
    }

    #[test]
    fn positional_style_combo_known() {
        assert!(allowed(&["echo", "-ne", "hello"], &ECHO_LIKE));
    }

    #[test]
    fn positional_style_combo_unknown_byte() {
        assert!(allowed(&["echo", "-nx", "hello"], &ECHO_LIKE));
    }

    #[test]
    fn positional_style_unknown_eq() {
        assert!(allowed(&["echo", "--foo=bar"], &ECHO_LIKE));
    }

    #[test]
    fn positional_style_with_max_positional() {
        let capped = FlagPolicy {
            standalone: WordSet::flags(&["-n"]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: Some(2),
            flag_style: FlagStyle::Positional,
        };
        assert!(allowed(&["echo", "--unknown", "hello"], &capped));
        assert!(!allowed(&["echo", "--a", "--b", "--c"], &capped));
    }
}