Skip to main content

safe_chains/
policy.rs

1use crate::parse::{Token, WordSet};
2
/// How [`check`] treats hyphen-prefixed tokens that match none of a policy's flag lists.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagStyle {
    /// Unrecognized hyphen-prefixed tokens cause the invocation to be rejected.
    Strict,
    /// Unrecognized hyphen-prefixed tokens are counted as positional arguments instead.
    Positional,
}
8
/// Allow-list describing which flags, and how many positional arguments,
/// [`check`] accepts for a single command.
pub struct FlagPolicy {
    /// Whole flag tokens that are accepted on their own; `check` consumes only the token itself.
    pub standalone: WordSet,
    /// Bytes accepted inside a short-flag cluster (e.g. the `r` and `n` of `-rn`)
    /// that do not take a value.
    pub standalone_short: &'static [u8],
    /// Whole flag tokens that take a value: either the following token is skipped,
    /// or the value is attached as `flag=value`.
    pub valued: WordSet,
    /// Bytes inside a short-flag cluster that take a value: the rest of the cluster
    /// when the byte is not last, otherwise the following token.
    pub valued_short: &'static [u8],
    /// Result returned by `check` when exactly one token (no arguments) is supplied.
    pub bare: bool,
    /// Upper bound on the number of positional arguments; `None` means no limit.
    pub max_positional: Option<usize>,
    /// Whether unrecognized hyphen-prefixed tokens are rejected or counted as positionals.
    pub flag_style: FlagStyle,
}
18
19impl FlagPolicy {
20    pub fn describe(&self) -> String {
21        use crate::docs::{wordset_items, DocBuilder};
22        let mut builder = DocBuilder::new();
23        let standalone = wordset_items(&self.standalone);
24        if !standalone.is_empty() {
25            builder = builder.section(format!("Allowed standalone flags: {standalone}."));
26        }
27        let valued = wordset_items(&self.valued);
28        if !valued.is_empty() {
29            builder = builder.section(format!("Allowed valued flags: {valued}."));
30        }
31        if self.bare {
32            builder = builder.section("Bare invocation allowed.".to_string());
33        }
34        if self.flag_style == FlagStyle::Positional {
35            builder = builder.section("Hyphen-prefixed positional arguments accepted.".to_string());
36        }
37        let result = builder.build();
38        if result.is_empty() && !self.bare {
39            return "Positional arguments only.".to_string();
40        }
41        result
42    }
43
44    pub fn flag_summary(&self) -> String {
45        use crate::docs::wordset_items;
46        let mut parts = Vec::new();
47        let standalone = wordset_items(&self.standalone);
48        if !standalone.is_empty() {
49            parts.push(format!("Flags: {standalone}"));
50        }
51        let valued = wordset_items(&self.valued);
52        if !valued.is_empty() {
53            parts.push(format!("Valued: {valued}"));
54        }
55        if self.flag_style == FlagStyle::Positional {
56            parts.push("Positional args accepted".to_string());
57        }
58        parts.join(". ")
59    }
60}
61
62pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
63    if tokens.len() == 1 {
64        return policy.bare;
65    }
66
67    let mut i = 1;
68    let mut positionals: usize = 0;
69    while i < tokens.len() {
70        let t = &tokens[i];
71
72        if *t == "--" {
73            positionals += tokens.len() - i - 1;
74            break;
75        }
76
77        if !t.starts_with('-') {
78            positionals += 1;
79            i += 1;
80            continue;
81        }
82
83        if policy.standalone.contains(t) {
84            i += 1;
85            continue;
86        }
87
88        if policy.valued.contains(t) {
89            i += 2;
90            continue;
91        }
92
93        if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
94            if policy.valued.contains(flag) {
95                i += 1;
96                continue;
97            }
98            if policy.flag_style == FlagStyle::Positional {
99                positionals += 1;
100                i += 1;
101                continue;
102            }
103            return false;
104        }
105
106        if t.starts_with("--") {
107            if policy.flag_style == FlagStyle::Positional {
108                positionals += 1;
109                i += 1;
110                continue;
111            }
112            return false;
113        }
114
115        let bytes = t.as_bytes();
116        let mut j = 1;
117        while j < bytes.len() {
118            let b = bytes[j];
119            let is_last = j == bytes.len() - 1;
120            if policy.standalone_short.contains(&b) {
121                j += 1;
122                continue;
123            }
124            if policy.valued_short.contains(&b) {
125                if is_last {
126                    i += 1;
127                }
128                break;
129            }
130            if policy.flag_style == FlagStyle::Positional {
131                positionals += 1;
132                break;
133            }
134            return false;
135        }
136        i += 1;
137    }
138    policy.max_positional.is_none_or(|max| positionals <= max)
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Strict grep-like policy: bare invocation denied, positionals unlimited.
    static TEST_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-c", "-r",
        ]),
        standalone_short: b"cHilnorsvw",
        valued: WordSet::new(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        valued_short: b"ABm",
        bare: false,
        max_positional: None,
        flag_style: FlagStyle::Strict,
    };

    /// Strict uniq-like policy: bare invocation allowed, at most one positional.
    static LIMITED_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&["--count", "-c", "-d", "-i", "-u"]),
        standalone_short: b"cdiu",
        valued: WordSet::new(&["--skip-fields", "-f", "-s"]),
        valued_short: b"fs",
        bare: true,
        max_positional: Some(1),
        flag_style: FlagStyle::Strict,
    };

    /// echo-like policy where unknown hyphen tokens count as positionals.
    static POSITIONAL_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&["-E", "-e", "-n"]),
        standalone_short: b"Een",
        valued: WordSet::new(&[]),
        valued_short: b"",
        bare: true,
        max_positional: None,
        flag_style: FlagStyle::Positional,
    };

    /// Turns `words` into tokens and reports whether `check` accepts them under `policy`.
    fn allowed(words: &[&str], policy: &FlagPolicy) -> bool {
        let tokens: Vec<Token> = words.iter().map(|word| Token::from_test(word)).collect();
        check(&tokens, policy)
    }

    // --- Strict policy, unlimited positionals ---

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!allowed(&["grep"], &TEST_POLICY));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let policy = FlagPolicy {
            standalone: WordSet::new(&[]),
            standalone_short: b"",
            valued: WordSet::new(&[]),
            valued_short: b"",
            bare: true,
            max_positional: None,
            flag_style: FlagStyle::Strict,
        };
        assert!(allowed(&["uname"], &policy));
    }

    #[test]
    fn standalone_long_flag() {
        assert!(allowed(&["grep", "--recursive", "pattern", "."], &TEST_POLICY));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(allowed(&["grep", "-r", "pattern", "."], &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(allowed(&["grep", "--max-count", "5", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(allowed(&["grep", "--max-count=5", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(allowed(&["grep", "-m", "5", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(allowed(&["grep", "-rn", "pattern", "."], &TEST_POLICY));
    }

    #[test]
    fn combined_short_with_valued_last() {
        assert!(allowed(&["grep", "-rnm", "5", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        assert!(allowed(&["grep", "-rmn", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!allowed(&["grep", "--exec", "cmd"], &TEST_POLICY));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!allowed(&["grep", "-z", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn unknown_combined_short_denied() {
        assert!(!allowed(&["grep", "-rz", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!allowed(&["grep", "--output=file.txt", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn double_dash_stops_checking() {
        assert!(allowed(&["grep", "--", "--not-a-flag", "file"], &TEST_POLICY));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(allowed(&["grep", "pattern", "file.txt", "other.txt"], &TEST_POLICY));
    }

    #[test]
    fn mixed_flags_and_positional() {
        let words = ["grep", "-rn", "--color", "--max-count", "10", "pattern", "."];
        assert!(allowed(&words, &TEST_POLICY));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(allowed(&["grep", "-A", "3", "-B", "3", "pattern"], &TEST_POLICY));
    }

    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(allowed(&["grep", "pattern", "-"], &TEST_POLICY));
    }

    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(allowed(&["grep", "--max-count"], &TEST_POLICY));
    }

    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(allowed(&["grep", "-c", "pattern"], &TEST_POLICY));
    }

    // --- Strict policy with a positional limit ---

    #[test]
    fn max_positional_within_limit() {
        assert!(allowed(&["uniq", "input.txt"], &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!allowed(&["uniq", "input.txt", "output.txt"], &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(allowed(&["uniq", "-c", "-f", "3", "input.txt"], &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!allowed(&["uniq", "-c", "input.txt", "output.txt"], &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_after_double_dash() {
        assert!(!allowed(&["uniq", "--", "input.txt", "output.txt"], &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(allowed(&["uniq"], &LIMITED_POLICY));
    }

    // --- Positional flag style ---

    #[test]
    fn positional_style_unknown_long() {
        assert!(allowed(&["echo", "--unknown", "hello"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_unknown_short() {
        assert!(allowed(&["echo", "-x", "hello"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_dashes() {
        assert!(allowed(&["echo", "---"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_known_flags_still_work() {
        assert!(allowed(&["echo", "-n", "hello"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_combo_known() {
        assert!(allowed(&["echo", "-ne", "hello"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_combo_unknown_byte() {
        assert!(allowed(&["echo", "-nx", "hello"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_unknown_eq() {
        assert!(allowed(&["echo", "--foo=bar"], &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_with_max_positional() {
        let policy = FlagPolicy {
            standalone: WordSet::new(&["-n"]),
            standalone_short: b"n",
            valued: WordSet::new(&[]),
            valued_short: b"",
            bare: true,
            max_positional: Some(2),
            flag_style: FlagStyle::Positional,
        };
        // Two positionals (one of them an unknown flag) fit the limit; three do not.
        assert!(allowed(&["echo", "--unknown", "hello"], &policy));
        assert!(!allowed(&["echo", "--a", "--b", "--c"], &policy));
    }
}
381}