Skip to main content

safe_chains/
policy.rs

1use crate::parse::{Token, WordSet};
2
/// How [`check_flags`] treats a hyphen-prefixed token that matches no known flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagStyle {
    /// An unrecognised flag rejects the whole invocation.
    Strict,
    /// An unrecognised flag is accepted and counted as a positional argument
    /// (useful for commands like `echo`, which print such tokens verbatim).
    Positional,
}
8
/// A collection of flag spellings that [`check_flags`] can query.
pub trait FlagSet {
    /// Returns `true` if `token` is exactly one of the flags in the set
    /// (for example `--max-count` or `-m`).
    fn contains_flag(&self, token: &str) -> bool;

    /// Returns `true` if `byte` names a single-character short flag in the set,
    /// i.e. one that may appear inside a combined cluster such as `-rn`.
    fn contains_short(&self, byte: u8) -> bool;
}
13
14impl FlagSet for WordSet {
15    fn contains_flag(&self, token: &str) -> bool {
16        self.contains(token)
17    }
18    fn contains_short(&self, byte: u8) -> bool {
19        self.contains_short(byte)
20    }
21}
22
23impl FlagSet for [String] {
24    fn contains_flag(&self, token: &str) -> bool {
25        self.iter().any(|f| f.as_str() == token)
26    }
27    fn contains_short(&self, byte: u8) -> bool {
28        self.iter().any(|f| f.len() == 2 && f.as_bytes()[1] == byte)
29    }
30}
31
32impl FlagSet for Vec<String> {
33    fn contains_flag(&self, token: &str) -> bool {
34        self.as_slice().contains_flag(token)
35    }
36    fn contains_short(&self, byte: u8) -> bool {
37        self.as_slice().contains_short(byte)
38    }
39}
40
41pub struct FlagPolicy {
42    pub standalone: WordSet,
43    pub valued: WordSet,
44    pub bare: bool,
45    pub max_positional: Option<usize>,
46    pub flag_style: FlagStyle,
47    pub numeric_dash: bool,
48}
49
50impl FlagPolicy {
51    pub fn describe(&self) -> String {
52        use crate::docs::wordset_items;
53        let mut lines = Vec::new();
54        let standalone = wordset_items(&self.standalone);
55        if !standalone.is_empty() {
56            lines.push(format!("- Allowed standalone flags: {standalone}"));
57        }
58        let valued = wordset_items(&self.valued);
59        if !valued.is_empty() {
60            lines.push(format!("- Allowed valued flags: {valued}"));
61        }
62        if self.bare {
63            lines.push("- Bare invocation allowed".to_string());
64        }
65        if self.flag_style == FlagStyle::Positional {
66            lines.push("- Hyphen-prefixed positional arguments accepted".to_string());
67        }
68        if self.numeric_dash {
69            lines.push("- Numeric shorthand accepted (e.g. -20 for -n 20)".to_string());
70        }
71        if lines.is_empty() && !self.bare {
72            return "- Positional arguments only".to_string();
73        }
74        lines.join("\n")
75    }
76
77}
78
79pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
80    check_flags(
81        tokens,
82        &policy.standalone,
83        &policy.valued,
84        policy.bare,
85        policy.max_positional,
86        policy.flag_style,
87        policy.numeric_dash,
88    )
89}
90
91pub fn check_flags<S: FlagSet + ?Sized, V: FlagSet + ?Sized>(
92    tokens: &[Token],
93    standalone: &S,
94    valued: &V,
95    bare: bool,
96    max_positional: Option<usize>,
97    flag_style: FlagStyle,
98    numeric_dash: bool,
99) -> bool {
100    if tokens.len() == 1 {
101        return bare;
102    }
103
104    let mut i = 1;
105    let mut positionals: usize = 0;
106    while i < tokens.len() {
107        let t = &tokens[i];
108
109        if *t == "--" {
110            positionals += tokens.len() - i - 1;
111            break;
112        }
113
114        if !t.starts_with('-') {
115            positionals += 1;
116            i += 1;
117            continue;
118        }
119
120        if numeric_dash && t.len() > 1 && t[1..].bytes().all(|b| b.is_ascii_digit()) {
121            i += 1;
122            continue;
123        }
124
125        if standalone.contains_flag(t) {
126            i += 1;
127            continue;
128        }
129
130        if valued.contains_flag(t) {
131            i += 2;
132            continue;
133        }
134
135        if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
136            if valued.contains_flag(flag) {
137                i += 1;
138                continue;
139            }
140            if flag_style == FlagStyle::Positional {
141                positionals += 1;
142                i += 1;
143                continue;
144            }
145            return false;
146        }
147
148        if t.starts_with("--") {
149            if flag_style == FlagStyle::Positional {
150                positionals += 1;
151                i += 1;
152                continue;
153            }
154            return false;
155        }
156
157        let bytes = t.as_bytes();
158        let mut j = 1;
159        while j < bytes.len() {
160            let b = bytes[j];
161            let is_last = j == bytes.len() - 1;
162            if standalone.contains_short(b) {
163                j += 1;
164                continue;
165            }
166            if valued.contains_short(b) {
167                if is_last {
168                    i += 1;
169                }
170                break;
171            }
172            if flag_style == FlagStyle::Positional {
173                positionals += 1;
174                break;
175            }
176            return false;
177        }
178        i += 1;
179    }
180    max_positional.is_none_or(|max| positionals <= max)
181}
182
183#[cfg(test)]
184mod tests {
185    use super::*;
186
187    static TEST_POLICY: FlagPolicy = FlagPolicy {
188        standalone: WordSet::flags(&[
189            "--color", "--count", "--help", "--recursive", "--version",
190            "-H", "-c", "-i", "-l", "-n", "-o", "-r", "-s", "-v", "-w",
191        ]),
192        valued: WordSet::flags(&[
193            "--after-context", "--before-context", "--max-count",
194            "-A", "-B", "-m",
195        ]),
196        bare: false,
197        max_positional: None,
198        flag_style: FlagStyle::Strict,
199        numeric_dash: false,
200    };
201
202    fn toks(words: &[&str]) -> Vec<Token> {
203        words.iter().map(|s| Token::from_test(s)).collect()
204    }
205
206    #[test]
207    fn bare_denied_when_bare_false() {
208        assert!(!check(&toks(&["grep"]), &TEST_POLICY));
209    }
210
211    #[test]
212    fn bare_allowed_when_bare_true() {
213        let policy = FlagPolicy {
214            standalone: WordSet::flags(&[]),
215            valued: WordSet::flags(&[]),
216            bare: true,
217            max_positional: None,
218            flag_style: FlagStyle::Strict,
219            numeric_dash: false,
220        };
221        assert!(check(&toks(&["uname"]), &policy));
222    }
223
224    #[test]
225    fn standalone_long_flag() {
226        assert!(check(&toks(&["grep", "--recursive", "pattern", "."]), &TEST_POLICY));
227    }
228
229    #[test]
230    fn standalone_short_flag() {
231        assert!(check(&toks(&["grep", "-r", "pattern", "."]), &TEST_POLICY));
232    }
233
234    #[test]
235    fn valued_long_flag_space() {
236        assert!(check(&toks(&["grep", "--max-count", "5", "pattern"]), &TEST_POLICY));
237    }
238
239    #[test]
240    fn valued_long_flag_eq() {
241        assert!(check(&toks(&["grep", "--max-count=5", "pattern"]), &TEST_POLICY));
242    }
243
244    #[test]
245    fn valued_short_flag_space() {
246        assert!(check(&toks(&["grep", "-m", "5", "pattern"]), &TEST_POLICY));
247    }
248
249    #[test]
250    fn combined_standalone_short() {
251        assert!(check(&toks(&["grep", "-rn", "pattern", "."]), &TEST_POLICY));
252    }
253
254    #[test]
255    fn combined_short_with_valued_last() {
256        assert!(check(&toks(&["grep", "-rnm", "5", "pattern"]), &TEST_POLICY));
257    }
258
259    #[test]
260    fn combined_short_valued_mid_consumes_rest() {
261        assert!(check(&toks(&["grep", "-rmn", "pattern"]), &TEST_POLICY));
262    }
263
264    #[test]
265    fn unknown_long_flag_denied() {
266        assert!(!check(&toks(&["grep", "--exec", "cmd"]), &TEST_POLICY));
267    }
268
269    #[test]
270    fn unknown_short_flag_denied() {
271        assert!(!check(&toks(&["grep", "-z", "pattern"]), &TEST_POLICY));
272    }
273
274    #[test]
275    fn unknown_combined_short_denied() {
276        assert!(!check(&toks(&["grep", "-rz", "pattern"]), &TEST_POLICY));
277    }
278
279    #[test]
280    fn unknown_long_eq_denied() {
281        assert!(!check(&toks(&["grep", "--output=file.txt", "pattern"]), &TEST_POLICY));
282    }
283
284    #[test]
285    fn double_dash_stops_checking() {
286        assert!(check(&toks(&["grep", "--", "--not-a-flag", "file"]), &TEST_POLICY));
287    }
288
289    #[test]
290    fn positional_args_allowed() {
291        assert!(check(&toks(&["grep", "pattern", "file.txt", "other.txt"]), &TEST_POLICY));
292    }
293
294    #[test]
295    fn mixed_flags_and_positional() {
296        assert!(check(
297            &toks(&["grep", "-rn", "--color", "--max-count", "10", "pattern", "."]),
298            &TEST_POLICY,
299        ));
300    }
301
302    #[test]
303    fn valued_short_in_explicit_form() {
304        assert!(check(&toks(&["grep", "-A", "3", "-B", "3", "pattern"]), &TEST_POLICY));
305    }
306
307    #[test]
308    fn bare_dash_allowed_as_stdin() {
309        assert!(check(&toks(&["grep", "pattern", "-"]), &TEST_POLICY));
310    }
311
312    #[test]
313    fn valued_flag_at_end_without_value() {
314        assert!(check(&toks(&["grep", "--max-count"]), &TEST_POLICY));
315    }
316
317    #[test]
318    fn single_short_in_wordset_and_byte_array() {
319        assert!(check(&toks(&["grep", "-c", "pattern"]), &TEST_POLICY));
320    }
321
322    static LIMITED_POLICY: FlagPolicy = FlagPolicy {
323        standalone: WordSet::flags(&["--count", "-c", "-d", "-i", "-u"]),
324        valued: WordSet::flags(&["--skip-fields", "-f", "-s"]),
325        bare: true,
326        max_positional: Some(1),
327        flag_style: FlagStyle::Strict,
328        numeric_dash: false,
329    };
330
331    #[test]
332    fn max_positional_within_limit() {
333        assert!(check(&toks(&["uniq", "input.txt"]), &LIMITED_POLICY));
334    }
335
336    #[test]
337    fn max_positional_exceeded() {
338        assert!(!check(&toks(&["uniq", "input.txt", "output.txt"]), &LIMITED_POLICY));
339    }
340
341    #[test]
342    fn max_positional_with_flags_within_limit() {
343        assert!(check(&toks(&["uniq", "-c", "-f", "3", "input.txt"]), &LIMITED_POLICY));
344    }
345
346    #[test]
347    fn max_positional_with_flags_exceeded() {
348        assert!(!check(&toks(&["uniq", "-c", "input.txt", "output.txt"]), &LIMITED_POLICY));
349    }
350
351    #[test]
352    fn max_positional_after_double_dash() {
353        assert!(!check(&toks(&["uniq", "--", "input.txt", "output.txt"]), &LIMITED_POLICY));
354    }
355
356    #[test]
357    fn max_positional_bare_allowed() {
358        assert!(check(&toks(&["uniq"]), &LIMITED_POLICY));
359    }
360
361    static POSITIONAL_POLICY: FlagPolicy = FlagPolicy {
362        standalone: WordSet::flags(&["-E", "-e", "-n"]),
363        valued: WordSet::flags(&[]),
364        bare: true,
365        max_positional: None,
366        flag_style: FlagStyle::Positional,
367        numeric_dash: false,
368    };
369
370    #[test]
371    fn positional_style_unknown_long() {
372        assert!(check(&toks(&["echo", "--unknown", "hello"]), &POSITIONAL_POLICY));
373    }
374
375    #[test]
376    fn positional_style_unknown_short() {
377        assert!(check(&toks(&["echo", "-x", "hello"]), &POSITIONAL_POLICY));
378    }
379
380    #[test]
381    fn positional_style_dashes() {
382        assert!(check(&toks(&["echo", "---"]), &POSITIONAL_POLICY));
383    }
384
385    #[test]
386    fn positional_style_known_flags_still_work() {
387        assert!(check(&toks(&["echo", "-n", "hello"]), &POSITIONAL_POLICY));
388    }
389
390    #[test]
391    fn positional_style_combo_known() {
392        assert!(check(&toks(&["echo", "-ne", "hello"]), &POSITIONAL_POLICY));
393    }
394
395    #[test]
396    fn positional_style_combo_unknown_byte() {
397        assert!(check(&toks(&["echo", "-nx", "hello"]), &POSITIONAL_POLICY));
398    }
399
400    #[test]
401    fn positional_style_unknown_eq() {
402        assert!(check(&toks(&["echo", "--foo=bar"]), &POSITIONAL_POLICY));
403    }
404
405    #[test]
406    fn positional_style_with_max_positional() {
407        let policy = FlagPolicy {
408            standalone: WordSet::flags(&["-n"]),
409            valued: WordSet::flags(&[]),
410            bare: true,
411            max_positional: Some(2),
412            flag_style: FlagStyle::Positional,
413            numeric_dash: false,
414        };
415        assert!(check(&toks(&["echo", "--unknown", "hello"]), &policy));
416        assert!(!check(&toks(&["echo", "--a", "--b", "--c"]), &policy));
417    }
418
419    static NUMERIC_DASH_POLICY: FlagPolicy = FlagPolicy {
420        standalone: WordSet::flags(&[
421            "--help", "--quiet", "--verbose", "--version",
422            "-V", "-h", "-q", "-v", "-z",
423        ]),
424        valued: WordSet::flags(&["--bytes", "--lines", "-c", "-n"]),
425        bare: true,
426        max_positional: None,
427        flag_style: FlagStyle::Strict,
428        numeric_dash: true,
429    };
430
431    #[test]
432    fn numeric_dash_single_digit() {
433        assert!(check(&toks(&["head", "-5"]), &NUMERIC_DASH_POLICY));
434    }
435
436    #[test]
437    fn numeric_dash_multi_digit() {
438        assert!(check(&toks(&["head", "-20"]), &NUMERIC_DASH_POLICY));
439    }
440
441    #[test]
442    fn numeric_dash_large_number() {
443        assert!(check(&toks(&["head", "-1000"]), &NUMERIC_DASH_POLICY));
444    }
445
446    #[test]
447    fn numeric_dash_with_file_arg() {
448        assert!(check(&toks(&["head", "-20", "file.txt"]), &NUMERIC_DASH_POLICY));
449    }
450
451    #[test]
452    fn numeric_dash_with_other_flags() {
453        assert!(check(&toks(&["head", "-q", "-20", "file.txt"]), &NUMERIC_DASH_POLICY));
454    }
455
456    #[test]
457    fn numeric_dash_zero() {
458        assert!(check(&toks(&["head", "-0"]), &NUMERIC_DASH_POLICY));
459    }
460
461    #[test]
462    fn numeric_dash_still_rejects_unknown_flags() {
463        assert!(!check(&toks(&["head", "-x"]), &NUMERIC_DASH_POLICY));
464    }
465
466    #[test]
467    fn numeric_dash_rejects_mixed_alpha_num() {
468        assert!(!check(&toks(&["head", "-20x"]), &NUMERIC_DASH_POLICY));
469    }
470
471    #[test]
472    fn numeric_dash_disabled_rejects_multi_digit() {
473        assert!(!check(&toks(&["grep", "-20", "pattern"]), &TEST_POLICY));
474    }
475}