Skip to main content

safe_chains/
policy.rs

1use crate::parse::{Token, WordSet};
2
/// Policy for flag-shaped tokens (`-X`, `--foo`) that the allowlist does not
/// recognize: deny them, or let them through as positional arguments.
/// `Strict` is the default and keeps the allowlist authoritative — every
/// unrecognized `-X` or `--foo` is denied.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum UnknownTolerance {
    /// Safe default: any unrecognized flag-shaped token is denied.
    #[default]
    Strict,
    /// Single-dash unknowns (`-X`, `-help`, `-mayDie`) pass as positional;
    /// double-dash unknowns are still denied. Intended for tools such as
    /// `pdftotext` whose long options use a single dash.
    Short,
    /// Double-dash unknowns (`--foo`, `--foo=value`) pass as positional;
    /// single-dash unknowns are still denied. Risky: most modern destructive
    /// options are double-dash, so this can wave through mutating flags.
    /// Reserve it for tools with a genuinely unbounded long-flag surface
    /// (AWS CLI service flags).
    Long,
    /// Both single- and double-dash unknowns pass as positional. The most
    /// permissive mode, carrying the risks of `Short` and `Long` together.
    Both,
}

impl UnknownTolerance {
    /// `true` when unknown single-dash tokens are tolerated (`Short`, `Both`).
    pub const fn allows_short(self) -> bool {
        match self {
            Self::Short | Self::Both => true,
            Self::Strict | Self::Long => false,
        }
    }

    /// `true` when unknown double-dash tokens are tolerated (`Long`, `Both`).
    pub const fn allows_long(self) -> bool {
        match self {
            Self::Long | Self::Both => true,
            Self::Strict | Self::Short => false,
        }
    }
}
34
35/// How the dispatcher treats tokens that look like flags but aren't in the
36/// allowlist. `unknown` controls flag-shaped unknowns; `numeric_dash` opts
37/// into `-NUMBER` shorthand (e.g. `head -20`).
38#[derive(Clone, Copy, Debug, Default)]
39pub struct FlagTolerance {
40    pub unknown: UnknownTolerance,
41    pub numeric_dash: bool,
42}
43
44impl FlagTolerance {
45    /// Strict allowlist: deny every unrecognized flag-shaped token.
46    /// `const`-callable for use in static `FlagPolicy` literals.
47    pub const fn strict() -> Self {
48        Self { unknown: UnknownTolerance::Strict, numeric_dash: false }
49    }
50}
51
/// Membership queries that a flag allowlist must answer.
pub trait FlagSet {
    /// Whether the whole token (e.g. `--color` or `-n`) is in the set.
    fn contains_flag(&self, token: &str) -> bool;

    /// Whether a one-letter short flag whose letter is `byte` is in the set.
    /// Used when scanning clustered short flags such as `-rn`.
    fn contains_short(&self, byte: u8) -> bool;
}
56
// `WordSet` is the allowlist type stored in `FlagPolicy`; both queries
// forward to lookups on `WordSet` itself.
impl FlagSet for WordSet {
    fn contains_flag(&self, token: &str) -> bool {
        // Delegates to the `contains` lookup on `WordSet`.
        self.contains(token)
    }
    fn contains_short(&self, byte: u8) -> bool {
        // NOTE(review): this relies on `WordSet` having an *inherent*
        // `contains_short`. Inherent methods take precedence over trait
        // methods during method resolution, so this delegates to that method
        // rather than recursing into this trait impl. Do not remove or rename
        // the inherent method without changing this call.
        self.contains_short(byte)
    }
}
65
66impl FlagSet for [String] {
67    fn contains_flag(&self, token: &str) -> bool {
68        self.iter().any(|f| f.as_str() == token)
69    }
70    fn contains_short(&self, byte: u8) -> bool {
71        self.iter().any(|f| f.len() == 2 && f.as_bytes()[1] == byte)
72    }
73}
74
75impl FlagSet for Vec<String> {
76    fn contains_flag(&self, token: &str) -> bool {
77        self.as_slice().contains_flag(token)
78    }
79    fn contains_short(&self, byte: u8) -> bool {
80        self.as_slice().contains_short(byte)
81    }
82}
83
84pub struct FlagPolicy {
85    pub standalone: WordSet,
86    pub valued: WordSet,
87    pub bare: bool,
88    pub max_positional: Option<usize>,
89    pub tolerance: FlagTolerance,
90}
91
92impl FlagPolicy {
93    pub fn describe(&self) -> String {
94        use crate::docs::wordset_items;
95        let mut lines = Vec::new();
96        let standalone = wordset_items(&self.standalone);
97        if !standalone.is_empty() {
98            lines.push(format!("- Allowed standalone flags: {standalone}"));
99        }
100        let valued = wordset_items(&self.valued);
101        if !valued.is_empty() {
102            lines.push(format!("- Allowed valued flags: {valued}"));
103        }
104        if self.bare {
105            lines.push("- Bare invocation allowed".to_string());
106        }
107        if self.tolerance.unknown != UnknownTolerance::Strict {
108            lines.push("- Hyphen-prefixed positional arguments accepted".to_string());
109        }
110        if self.tolerance.numeric_dash {
111            lines.push("- Numeric shorthand accepted (e.g. -20 for -n 20)".to_string());
112        }
113        if lines.is_empty() && !self.bare {
114            return "- Positional arguments only".to_string();
115        }
116        lines.join("\n")
117    }
118
119}
120
121pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
122    check_flags(
123        tokens,
124        &policy.standalone,
125        &policy.valued,
126        policy.bare,
127        policy.max_positional,
128        policy.tolerance,
129    )
130}
131
132pub fn check_flags<S: FlagSet + ?Sized, V: FlagSet + ?Sized>(
133    tokens: &[Token],
134    standalone: &S,
135    valued: &V,
136    bare: bool,
137    max_positional: Option<usize>,
138    tolerance: FlagTolerance,
139) -> bool {
140    if tokens.len() == 1 {
141        return bare;
142    }
143
144    let mut i = 1;
145    let mut positionals: usize = 0;
146    while i < tokens.len() {
147        let t = &tokens[i];
148
149        if *t == "--" {
150            positionals += tokens.len() - i - 1;
151            break;
152        }
153
154        if !t.starts_with('-') {
155            positionals += 1;
156            i += 1;
157            continue;
158        }
159
160        if tolerance.numeric_dash && t.len() > 1 && t[1..].bytes().all(|b| b.is_ascii_digit()) {
161            i += 1;
162            continue;
163        }
164
165        if standalone.contains_flag(t) {
166            i += 1;
167            continue;
168        }
169
170        if valued.contains_flag(t) {
171            i += 2;
172            continue;
173        }
174
175        if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
176            if valued.contains_flag(flag) {
177                i += 1;
178                continue;
179            }
180            // `--foo=value` forms are governed by the long-flag tolerance.
181            if tolerance.unknown.allows_long() {
182                positionals += 1;
183                i += 1;
184                continue;
185            }
186            return false;
187        }
188
189        if t.starts_with("--") {
190            if tolerance.unknown.allows_long() {
191                positionals += 1;
192                i += 1;
193                continue;
194            }
195            return false;
196        }
197
198        let bytes = t.as_bytes();
199        let mut j = 1;
200        while j < bytes.len() {
201            let b = bytes[j];
202            let is_last = j == bytes.len() - 1;
203            if standalone.contains_short(b) {
204                j += 1;
205                continue;
206            }
207            if valued.contains_short(b) {
208                if is_last {
209                    i += 1;
210                }
211                break;
212            }
213            if tolerance.unknown.allows_short() {
214                positionals += 1;
215                break;
216            }
217            return false;
218        }
219        i += 1;
220    }
221    max_positional.is_none_or(|max| positionals <= max)
222}
223
#[cfg(test)]
mod tests {
    // Unit tests for `check` / `check_flags`, grouped by the policy fixture
    // they exercise.
    use super::*;

    // grep-like allowlist under strict tolerance; most tests below use it.
    static TEST_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-H", "-c", "-i", "-l", "-n", "-o", "-r", "-s", "-v", "-w",
        ]),
        valued: WordSet::flags(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        bare: false,
        max_positional: None,
        tolerance: FlagTolerance::strict(),
    };

    // Builds a token list from plain words; `words[0]` is the command name.
    fn toks(words: &[&str]) -> Vec<Token> {
        words.iter().map(|s| Token::from_test(s)).collect()
    }

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!check(&toks(&["grep"]), &TEST_POLICY));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let policy = FlagPolicy {
            standalone: WordSet::flags(&[]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: None,
            tolerance: FlagTolerance::strict(),
        };
        assert!(check(&toks(&["uname"]), &policy));
    }

    #[test]
    fn standalone_long_flag() {
        assert!(check(&toks(&["grep", "--recursive", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(check(&toks(&["grep", "-r", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(check(&toks(&["grep", "--max-count", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(check(&toks(&["grep", "--max-count=5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(check(&toks(&["grep", "-m", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(check(&toks(&["grep", "-rn", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn combined_short_with_valued_last() {
        // `-m` is the last letter, so the next token (`5`) is its value.
        assert!(check(&toks(&["grep", "-rnm", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        // `-m` mid-cluster takes the rest of the token (`n`) as its value.
        assert!(check(&toks(&["grep", "-rmn", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!check(&toks(&["grep", "--exec", "cmd"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!check(&toks(&["grep", "-z", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_combined_short_denied() {
        assert!(!check(&toks(&["grep", "-rz", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!check(&toks(&["grep", "--output=file.txt", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn double_dash_stops_checking() {
        assert!(check(&toks(&["grep", "--", "--not-a-flag", "file"]), &TEST_POLICY));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(check(&toks(&["grep", "pattern", "file.txt", "other.txt"]), &TEST_POLICY));
    }

    #[test]
    fn mixed_flags_and_positional() {
        assert!(check(
            &toks(&["grep", "-rn", "--color", "--max-count", "10", "pattern", "."]),
            &TEST_POLICY,
        ));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(check(&toks(&["grep", "-A", "3", "-B", "3", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(check(&toks(&["grep", "pattern", "-"]), &TEST_POLICY));
    }

    #[test]
    fn valued_flag_at_end_without_value() {
        // The checker does not require the value token to be present.
        assert!(check(&toks(&["grep", "--max-count"]), &TEST_POLICY));
    }

    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(check(&toks(&["grep", "-c", "pattern"]), &TEST_POLICY));
    }

    // uniq-like policy: bare invocation allowed, at most one positional.
    static LIMITED_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--count", "-c", "-d", "-i", "-u"]),
        valued: WordSet::flags(&["--skip-fields", "-f", "-s"]),
        bare: true,
        max_positional: Some(1),
        tolerance: FlagTolerance::strict(),
    };

    #[test]
    fn max_positional_within_limit() {
        assert!(check(&toks(&["uniq", "input.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!check(&toks(&["uniq", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(check(&toks(&["uniq", "-c", "-f", "3", "input.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!check(&toks(&["uniq", "-c", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_after_double_dash() {
        // Tokens after `--` still count toward the positional limit.
        assert!(!check(&toks(&["uniq", "--", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(check(&toks(&["uniq"]), &LIMITED_POLICY));
    }

    // echo-like policy with `UnknownTolerance::Both`.
    static BOTH_TOLERANCES_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["-E", "-e", "-n"]),
        valued: WordSet::flags(&[]),
        bare: true,
        max_positional: None,
        tolerance: FlagTolerance { unknown: UnknownTolerance::Both, numeric_dash: false },
    };

    #[test]
    fn both_tolerances_accept_unknown_long() {
        assert!(check(&toks(&["echo", "--unknown", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_accept_unknown_short() {
        assert!(check(&toks(&["echo", "-x", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_accept_triple_dash() {
        assert!(check(&toks(&["echo", "---"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_known_flags_still_work() {
        assert!(check(&toks(&["echo", "-n", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_combo_known_short() {
        assert!(check(&toks(&["echo", "-ne", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_combo_unknown_short_byte() {
        assert!(check(&toks(&["echo", "-nx", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_unknown_eq_form() {
        assert!(check(&toks(&["echo", "--foo=bar"]), &BOTH_TOLERANCES_POLICY));
    }

    // ============ Narrow tolerance: short-only ============
    // `UnknownTolerance::Short` accepts unknown single-dash tokens
    // (-X, -mayDie, -help) as positional, while leaving double-dash unknowns
    // strict. This is the safer setting because most modern destructive
    // flags are double-dash.

    static SHORT_ONLY_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--help"]),
        valued: WordSet::flags(&[]),
        bare: false,
        max_positional: None,
        tolerance: FlagTolerance { unknown: UnknownTolerance::Short, numeric_dash: false },
    };

    #[test]
    fn short_only_accepts_unknown_dash_letter() {
        assert!(check(&toks(&["sample", "-mayDie"]), &SHORT_ONLY_POLICY));
    }

    #[test]
    fn short_only_accepts_single_dash_long_word() {
        // pdftotext-style: `-help`, `-layout`, `-version` (single dash + word)
        assert!(check(&toks(&["pdftotext", "-layout"]), &SHORT_ONLY_POLICY));
    }

    #[test]
    fn short_only_denies_unknown_double_dash() {
        // The whole point of the narrow split: --evil-flag must not slip
        // through when only short-tolerance is on.
        assert!(!check(&toks(&["sample", "--evil-flag"]), &SHORT_ONLY_POLICY));
    }

    #[test]
    fn short_only_denies_unknown_eq_form() {
        assert!(!check(&toks(&["sample", "--evil=value"]), &SHORT_ONLY_POLICY));
    }

    #[test]
    fn short_only_known_long_flag_still_works() {
        assert!(check(&toks(&["sample", "--help"]), &SHORT_ONLY_POLICY));
    }

    // ============ Narrow tolerance: long-only ============
    // `UnknownTolerance::Long` accepts unknown double-dash tokens as
    // positional. This is the dangerous form; reserved for tools like AWS
    // CLI whose long-flag surface is genuinely unbounded.

    static LONG_ONLY_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--help"]),
        valued: WordSet::flags(&[]),
        bare: false,
        max_positional: None,
        tolerance: FlagTolerance { unknown: UnknownTolerance::Long, numeric_dash: false },
    };

    #[test]
    fn long_only_accepts_unknown_double_dash() {
        assert!(check(&toks(&["aws", "--some-aws-flag"]), &LONG_ONLY_POLICY));
    }

    #[test]
    fn long_only_accepts_unknown_eq_form() {
        assert!(check(
            &toks(&["aws", "--filter=Name=tag,Values=foo"]),
            &LONG_ONLY_POLICY,
        ));
    }

    #[test]
    fn long_only_denies_unknown_short_dash() {
        assert!(!check(&toks(&["aws", "-x"]), &LONG_ONLY_POLICY));
    }

    // ============ Strict: no unknown tolerance (`UnknownTolerance::Strict`) ============

    static STRICT_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--help"]),
        valued: WordSet::flags(&[]),
        bare: false,
        max_positional: None,
        tolerance: FlagTolerance::strict(),
    };

    #[test]
    fn strict_denies_unknown_short() {
        assert!(!check(&toks(&["foo", "-evil"]), &STRICT_POLICY));
    }

    #[test]
    fn strict_denies_unknown_long() {
        assert!(!check(&toks(&["foo", "--evil"]), &STRICT_POLICY));
    }

    #[test]
    fn strict_known_flag_passes() {
        assert!(check(&toks(&["foo", "--help"]), &STRICT_POLICY));
    }

    #[test]
    fn both_tolerances_with_max_positional() {
        // Tolerated unknown flags are counted as positional arguments.
        let policy = FlagPolicy {
            standalone: WordSet::flags(&["-n"]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: Some(2),
            tolerance: FlagTolerance { unknown: UnknownTolerance::Both, numeric_dash: false },
        };
        assert!(check(&toks(&["echo", "--unknown", "hello"]), &policy));
        assert!(!check(&toks(&["echo", "--a", "--b", "--c"]), &policy));
    }

    // head-like policy with `-NUMBER` shorthand enabled.
    static NUMERIC_DASH_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--help", "--quiet", "--verbose", "--version",
            "-V", "-h", "-q", "-v", "-z",
        ]),
        valued: WordSet::flags(&["--bytes", "--lines", "-c", "-n"]),
        bare: true,
        max_positional: None,
        tolerance: FlagTolerance { numeric_dash: true, ..FlagTolerance::strict() },
    };

    #[test]
    fn numeric_dash_single_digit() {
        assert!(check(&toks(&["head", "-5"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_multi_digit() {
        assert!(check(&toks(&["head", "-20"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_large_number() {
        assert!(check(&toks(&["head", "-1000"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_with_file_arg() {
        assert!(check(&toks(&["head", "-20", "file.txt"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_with_other_flags() {
        assert!(check(&toks(&["head", "-q", "-20", "file.txt"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_zero() {
        assert!(check(&toks(&["head", "-0"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_still_rejects_unknown_flags() {
        assert!(!check(&toks(&["head", "-x"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_rejects_mixed_alpha_num() {
        assert!(!check(&toks(&["head", "-20x"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_disabled_rejects_multi_digit() {
        assert!(!check(&toks(&["grep", "-20", "pattern"]), &TEST_POLICY));
    }
}
608}