Skip to main content

safe_chains/
policy.rs

1use crate::parse::{Token, WordSet};
2
/// How [`check`] treats a hyphen-prefixed token that matches none of the
/// flags listed in a [`FlagPolicy`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagStyle {
    /// Unrecognised hyphen-prefixed tokens cause the command to be rejected.
    Strict,
    /// Unrecognised hyphen-prefixed tokens are counted as positional
    /// arguments instead of being rejected.
    Positional,
}
8
/// Allow-list describing which flags and how many positional arguments a
/// command may be invoked with; evaluated by [`check`].
pub struct FlagPolicy {
    /// Flags accepted as a whole token that take no value (matched exactly).
    pub standalone: WordSet,
    /// Single-byte flags accepted without a value inside a single-dash
    /// cluster such as `-abc`.
    pub standalone_short: &'static [u8],
    /// Flags that take a value, given either as the following token or
    /// inline as `flag=value`.
    pub valued: WordSet,
    /// Single-byte flags inside a single-dash cluster that take a value: the
    /// rest of the cluster, or the following token when the byte is last.
    pub valued_short: &'static [u8],
    /// Whether the command is allowed with no arguments at all.
    pub bare: bool,
    /// Upper bound on the number of positional arguments; `None` means no
    /// limit is enforced.
    pub max_positional: Option<usize>,
    /// Treatment of hyphen-prefixed tokens that match no listed flag.
    pub flag_style: FlagStyle,
}
18
19impl FlagPolicy {
20    pub fn describe(&self) -> String {
21        use crate::docs::wordset_items;
22        let mut lines = Vec::new();
23        let standalone = wordset_items(&self.standalone);
24        if !standalone.is_empty() {
25            lines.push(format!("- Allowed standalone flags: {standalone}"));
26        }
27        let valued = wordset_items(&self.valued);
28        if !valued.is_empty() {
29            lines.push(format!("- Allowed valued flags: {valued}"));
30        }
31        if self.bare {
32            lines.push("- Bare invocation allowed".to_string());
33        }
34        if self.flag_style == FlagStyle::Positional {
35            lines.push("- Hyphen-prefixed positional arguments accepted".to_string());
36        }
37        if lines.is_empty() && !self.bare {
38            return "- Positional arguments only".to_string();
39        }
40        lines.join("\n")
41    }
42
43    pub fn flag_summary(&self) -> String {
44        use crate::docs::wordset_items;
45        let mut parts = Vec::new();
46        let standalone = wordset_items(&self.standalone);
47        if !standalone.is_empty() {
48            parts.push(format!("Flags: {standalone}"));
49        }
50        let valued = wordset_items(&self.valued);
51        if !valued.is_empty() {
52            parts.push(format!("Valued: {valued}"));
53        }
54        if self.flag_style == FlagStyle::Positional {
55            parts.push("Positional args accepted".to_string());
56        }
57        parts.join(". ")
58    }
59}
60
61pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
62    if tokens.len() == 1 {
63        return policy.bare;
64    }
65
66    let mut i = 1;
67    let mut positionals: usize = 0;
68    while i < tokens.len() {
69        let t = &tokens[i];
70
71        if *t == "--" {
72            positionals += tokens.len() - i - 1;
73            break;
74        }
75
76        if !t.starts_with('-') {
77            positionals += 1;
78            i += 1;
79            continue;
80        }
81
82        if policy.standalone.contains(t) {
83            i += 1;
84            continue;
85        }
86
87        if policy.valued.contains(t) {
88            i += 2;
89            continue;
90        }
91
92        if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
93            if policy.valued.contains(flag) {
94                i += 1;
95                continue;
96            }
97            if policy.flag_style == FlagStyle::Positional {
98                positionals += 1;
99                i += 1;
100                continue;
101            }
102            return false;
103        }
104
105        if t.starts_with("--") {
106            if policy.flag_style == FlagStyle::Positional {
107                positionals += 1;
108                i += 1;
109                continue;
110            }
111            return false;
112        }
113
114        let bytes = t.as_bytes();
115        let mut j = 1;
116        while j < bytes.len() {
117            let b = bytes[j];
118            let is_last = j == bytes.len() - 1;
119            if policy.standalone_short.contains(&b) {
120                j += 1;
121                continue;
122            }
123            if policy.valued_short.contains(&b) {
124                if is_last {
125                    i += 1;
126                }
127                break;
128            }
129            if policy.flag_style == FlagStyle::Positional {
130                positionals += 1;
131                break;
132            }
133            return false;
134        }
135        i += 1;
136    }
137    policy.max_positional.is_none_or(|max| positionals <= max)
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    // grep-flavoured strict policy: unlimited positionals, bare invocation rejected.
    static TEST_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-c", "-r",
        ]),
        standalone_short: b"cHilnorsvw",
        valued: WordSet::new(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        valued_short: b"ABm",
        bare: false,
        max_positional: None,
        flag_style: FlagStyle::Strict,
    };

    // uniq-flavoured strict policy: at most one positional, bare invocation allowed.
    static LIMITED_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&["--count", "-c", "-d", "-i", "-u"]),
        standalone_short: b"cdiu",
        valued: WordSet::new(&["--skip-fields", "-f", "-s"]),
        valued_short: b"fs",
        bare: true,
        max_positional: Some(1),
        flag_style: FlagStyle::Strict,
    };

    // echo-flavoured policy: unknown hyphen-prefixed tokens count as positionals.
    static POSITIONAL_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&["-E", "-e", "-n"]),
        standalone_short: b"Een",
        valued: WordSet::new(&[]),
        valued_short: b"",
        bare: true,
        max_positional: None,
        flag_style: FlagStyle::Positional,
    };

    /// Builds a token vector from plain words.
    fn toks(words: &[&str]) -> Vec<Token> {
        let mut out = Vec::with_capacity(words.len());
        for word in words {
            out.push(Token::from_test(word));
        }
        out
    }

    /// Runs `check` on `words` under `policy`.
    fn accepts(policy: &FlagPolicy, words: &[&str]) -> bool {
        check(&toks(words), policy)
    }

    // --- bare invocation -------------------------------------------------

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!accepts(&TEST_POLICY, &["grep"]));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let bare_only = FlagPolicy {
            standalone: WordSet::new(&[]),
            standalone_short: b"",
            valued: WordSet::new(&[]),
            valued_short: b"",
            bare: true,
            max_positional: None,
            flag_style: FlagStyle::Strict,
        };
        assert!(accepts(&bare_only, &["uname"]));
    }

    // --- known flags -----------------------------------------------------

    #[test]
    fn standalone_long_flag() {
        assert!(accepts(&TEST_POLICY, &["grep", "--recursive", "pattern", "."]));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(accepts(&TEST_POLICY, &["grep", "-r", "pattern", "."]));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(accepts(&TEST_POLICY, &["grep", "--max-count", "5", "pattern"]));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(accepts(&TEST_POLICY, &["grep", "--max-count=5", "pattern"]));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(accepts(&TEST_POLICY, &["grep", "-m", "5", "pattern"]));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(accepts(&TEST_POLICY, &["grep", "-rn", "pattern", "."]));
    }

    #[test]
    fn combined_short_with_valued_last() {
        assert!(accepts(&TEST_POLICY, &["grep", "-rnm", "5", "pattern"]));
    }

    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        assert!(accepts(&TEST_POLICY, &["grep", "-rmn", "pattern"]));
    }

    // --- unknown flags under the strict style ----------------------------

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!accepts(&TEST_POLICY, &["grep", "--exec", "cmd"]));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!accepts(&TEST_POLICY, &["grep", "-z", "pattern"]));
    }

    #[test]
    fn unknown_combined_short_denied() {
        assert!(!accepts(&TEST_POLICY, &["grep", "-rz", "pattern"]));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!accepts(&TEST_POLICY, &["grep", "--output=file.txt", "pattern"]));
    }

    // --- positionals and separators --------------------------------------

    #[test]
    fn double_dash_stops_checking() {
        assert!(accepts(&TEST_POLICY, &["grep", "--", "--not-a-flag", "file"]));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(accepts(&TEST_POLICY, &["grep", "pattern", "file.txt", "other.txt"]));
    }

    #[test]
    fn mixed_flags_and_positional() {
        let words = ["grep", "-rn", "--color", "--max-count", "10", "pattern", "."];
        assert!(accepts(&TEST_POLICY, &words));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(accepts(&TEST_POLICY, &["grep", "-A", "3", "-B", "3", "pattern"]));
    }

    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(accepts(&TEST_POLICY, &["grep", "pattern", "-"]));
    }

    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(accepts(&TEST_POLICY, &["grep", "--max-count"]));
    }

    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(accepts(&TEST_POLICY, &["grep", "-c", "pattern"]));
    }

    // --- max_positional ---------------------------------------------------

    #[test]
    fn max_positional_within_limit() {
        assert!(accepts(&LIMITED_POLICY, &["uniq", "input.txt"]));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!accepts(&LIMITED_POLICY, &["uniq", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(accepts(&LIMITED_POLICY, &["uniq", "-c", "-f", "3", "input.txt"]));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!accepts(&LIMITED_POLICY, &["uniq", "-c", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_after_double_dash() {
        assert!(!accepts(&LIMITED_POLICY, &["uniq", "--", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(accepts(&LIMITED_POLICY, &["uniq"]));
    }

    // --- positional flag style --------------------------------------------

    #[test]
    fn positional_style_unknown_long() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "--unknown", "hello"]));
    }

    #[test]
    fn positional_style_unknown_short() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "-x", "hello"]));
    }

    #[test]
    fn positional_style_dashes() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "---"]));
    }

    #[test]
    fn positional_style_known_flags_still_work() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "-n", "hello"]));
    }

    #[test]
    fn positional_style_combo_known() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "-ne", "hello"]));
    }

    #[test]
    fn positional_style_combo_unknown_byte() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "-nx", "hello"]));
    }

    #[test]
    fn positional_style_unknown_eq() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "--foo=bar"]));
    }

    #[test]
    fn positional_style_with_max_positional() {
        let capped = FlagPolicy {
            standalone: WordSet::new(&["-n"]),
            standalone_short: b"n",
            valued: WordSet::new(&[]),
            valued_short: b"",
            bare: true,
            max_positional: Some(2),
            flag_style: FlagStyle::Positional,
        };
        assert!(accepts(&capped, &["echo", "--unknown", "hello"]));
        assert!(!accepts(&capped, &["echo", "--a", "--b", "--c"]));
    }
}