Skip to main content

safe_chains/
policy.rs

1use crate::parse::{Token, WordSet};
2
/// How [`check`] treats a hyphen-prefixed token that the policy does not
/// recognise as a flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagStyle {
    /// Unrecognised hyphen-prefixed tokens cause the command to be rejected.
    Strict,
    /// Unrecognised hyphen-prefixed tokens are counted as positional
    /// arguments instead of being rejected.
    Positional,
}
8
/// Allow-list describing which arguments a command may be invoked with.
///
/// The fields are consumed by [`check`], which walks a tokenized command line
/// and accepts it only if every token is permitted by this policy.
pub struct FlagPolicy {
    /// Flags accepted as a whole token that take no value; each one occupies
    /// exactly one token.
    pub standalone: WordSet,
    /// Single-byte flags that take no value and may appear inside a combined
    /// short-flag cluster such as `-abc`.
    pub standalone_short: &'static [u8],
    /// Flags that take a value, supplied either as the following token or
    /// attached as `flag=value`.
    pub valued: WordSet,
    /// Single-byte flags inside a cluster that take a value: the rest of the
    /// cluster when the flag is not last, otherwise the following token.
    pub valued_short: &'static [u8],
    /// Whether the command may be invoked with no arguments at all (the token
    /// list holds only the command name).
    pub bare: bool,
    /// Upper bound on the number of positional arguments; `None` means no
    /// limit is enforced.
    pub max_positional: Option<usize>,
    /// How hyphen-prefixed tokens that match none of the flag lists are
    /// handled.
    pub flag_style: FlagStyle,
}
18
19impl FlagPolicy {
20    pub fn describe(&self) -> String {
21        use crate::docs::{wordset_items, DocBuilder};
22        let mut builder = DocBuilder::new();
23        let standalone = wordset_items(&self.standalone);
24        if !standalone.is_empty() {
25            builder = builder.section(format!("Allowed standalone flags: {standalone}."));
26        }
27        let valued = wordset_items(&self.valued);
28        if !valued.is_empty() {
29            builder = builder.section(format!("Allowed valued flags: {valued}."));
30        }
31        if self.bare {
32            builder = builder.section("Bare invocation allowed.".to_string());
33        }
34        if self.flag_style == FlagStyle::Positional {
35            builder = builder.section("Hyphen-prefixed positional arguments accepted.".to_string());
36        }
37        builder.build()
38    }
39
40    pub fn flag_summary(&self) -> String {
41        use crate::docs::wordset_items;
42        let mut parts = Vec::new();
43        let standalone = wordset_items(&self.standalone);
44        if !standalone.is_empty() {
45            parts.push(format!("Flags: {standalone}"));
46        }
47        let valued = wordset_items(&self.valued);
48        if !valued.is_empty() {
49            parts.push(format!("Valued: {valued}"));
50        }
51        if self.flag_style == FlagStyle::Positional {
52            parts.push("Positional args accepted".to_string());
53        }
54        parts.join(". ")
55    }
56}
57
58pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
59    if tokens.len() == 1 {
60        return policy.bare;
61    }
62
63    let mut i = 1;
64    let mut positionals: usize = 0;
65    while i < tokens.len() {
66        let t = &tokens[i];
67
68        if *t == "--" {
69            positionals += tokens.len() - i - 1;
70            break;
71        }
72
73        if !t.starts_with('-') {
74            positionals += 1;
75            i += 1;
76            continue;
77        }
78
79        if policy.standalone.contains(t) {
80            i += 1;
81            continue;
82        }
83
84        if policy.valued.contains(t) {
85            i += 2;
86            continue;
87        }
88
89        if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
90            if policy.valued.contains(flag) {
91                i += 1;
92                continue;
93            }
94            if policy.flag_style == FlagStyle::Positional {
95                positionals += 1;
96                i += 1;
97                continue;
98            }
99            return false;
100        }
101
102        if t.starts_with("--") {
103            if policy.flag_style == FlagStyle::Positional {
104                positionals += 1;
105                i += 1;
106                continue;
107            }
108            return false;
109        }
110
111        let bytes = t.as_bytes();
112        let mut j = 1;
113        while j < bytes.len() {
114            let b = bytes[j];
115            let is_last = j == bytes.len() - 1;
116            if policy.standalone_short.contains(&b) {
117                j += 1;
118                continue;
119            }
120            if policy.valued_short.contains(&b) {
121                if is_last {
122                    i += 1;
123                }
124                break;
125            }
126            if policy.flag_style == FlagStyle::Positional {
127                positionals += 1;
128                break;
129            }
130            return false;
131        }
132        i += 1;
133    }
134    policy.max_positional.is_none_or(|max| positionals <= max)
135}
136
// Unit tests for `check`, grouped by the fixture policy they exercise:
// a strict grep-like policy, a strict policy with a positional limit, and a
// policy using `FlagStyle::Positional`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Strict policy with no positional limit and bare invocation disallowed.
    /// `-c` and `-r` appear both as whole words and in `standalone_short`.
    static TEST_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-c", "-r",
        ]),
        standalone_short: b"cHilnorsvw",
        valued: WordSet::new(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        valued_short: b"ABm",
        bare: false,
        max_positional: None,
        flag_style: FlagStyle::Strict,
    };

    /// Builds a token list from string literals; the first word stands for
    /// the command name.
    fn toks(words: &[&str]) -> Vec<Token> {
        words.iter().map(|s| Token::from_test(s)).collect()
    }

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!check(&toks(&["grep"]), &TEST_POLICY));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let policy = FlagPolicy {
            standalone: WordSet::new(&[]),
            standalone_short: b"",
            valued: WordSet::new(&[]),
            valued_short: b"",
            bare: true,
            max_positional: None,
            flag_style: FlagStyle::Strict,
        };
        assert!(check(&toks(&["uname"]), &policy));
    }

    #[test]
    fn standalone_long_flag() {
        assert!(check(&toks(&["grep", "--recursive", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(check(&toks(&["grep", "-r", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(check(&toks(&["grep", "--max-count", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(check(&toks(&["grep", "--max-count=5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(check(&toks(&["grep", "-m", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(check(&toks(&["grep", "-rn", "pattern", "."]), &TEST_POLICY));
    }

    // `m` is the last byte of the cluster, so the next token ("5") is its value.
    #[test]
    fn combined_short_with_valued_last() {
        assert!(check(&toks(&["grep", "-rnm", "5", "pattern"]), &TEST_POLICY));
    }

    // `m` takes a value, so the trailing `n` is read as that value rather
    // than as another flag.
    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        assert!(check(&toks(&["grep", "-rmn", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!check(&toks(&["grep", "--exec", "cmd"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!check(&toks(&["grep", "-z", "pattern"]), &TEST_POLICY));
    }

    // One unknown byte anywhere in a cluster rejects the whole token.
    #[test]
    fn unknown_combined_short_denied() {
        assert!(!check(&toks(&["grep", "-rz", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!check(&toks(&["grep", "--output=file.txt", "pattern"]), &TEST_POLICY));
    }

    // Tokens after `--` are positionals and are not validated as flags.
    #[test]
    fn double_dash_stops_checking() {
        assert!(check(&toks(&["grep", "--", "--not-a-flag", "file"]), &TEST_POLICY));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(check(&toks(&["grep", "pattern", "file.txt", "other.txt"]), &TEST_POLICY));
    }

    #[test]
    fn mixed_flags_and_positional() {
        assert!(check(
            &toks(&["grep", "-rn", "--color", "--max-count", "10", "pattern", "."]),
            &TEST_POLICY,
        ));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(check(&toks(&["grep", "-A", "3", "-B", "3", "pattern"]), &TEST_POLICY));
    }

    // A lone `-` is accepted even under the strict policy.
    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(check(&toks(&["grep", "pattern", "-"]), &TEST_POLICY));
    }

    // The checker does not require a valued flag's value to be present.
    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(check(&toks(&["grep", "--max-count"]), &TEST_POLICY));
    }

    // `-c` is listed both as a whole word and as a short-flag byte.
    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(check(&toks(&["grep", "-c", "pattern"]), &TEST_POLICY));
    }

    /// Strict policy that allows bare invocation and at most one positional.
    static LIMITED_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&["--count", "-c", "-d", "-i", "-u"]),
        standalone_short: b"cdiu",
        valued: WordSet::new(&["--skip-fields", "-f", "-s"]),
        valued_short: b"fs",
        bare: true,
        max_positional: Some(1),
        flag_style: FlagStyle::Strict,
    };

    #[test]
    fn max_positional_within_limit() {
        assert!(check(&toks(&["uniq", "input.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!check(&toks(&["uniq", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    // The value of `-f` ("3") is not counted as a positional.
    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(check(&toks(&["uniq", "-c", "-f", "3", "input.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!check(&toks(&["uniq", "-c", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    // Tokens after `--` still count toward the positional limit.
    #[test]
    fn max_positional_after_double_dash() {
        assert!(!check(&toks(&["uniq", "--", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(check(&toks(&["uniq"]), &LIMITED_POLICY));
    }

    /// Policy using `FlagStyle::Positional`: unrecognised hyphen-prefixed
    /// tokens are accepted as positionals. No positional limit.
    static POSITIONAL_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&["-E", "-e", "-n"]),
        standalone_short: b"Een",
        valued: WordSet::new(&[]),
        valued_short: b"",
        bare: true,
        max_positional: None,
        flag_style: FlagStyle::Positional,
    };

    #[test]
    fn positional_style_unknown_long() {
        assert!(check(&toks(&["echo", "--unknown", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_unknown_short() {
        assert!(check(&toks(&["echo", "-x", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_dashes() {
        assert!(check(&toks(&["echo", "---"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_known_flags_still_work() {
        assert!(check(&toks(&["echo", "-n", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_combo_known() {
        assert!(check(&toks(&["echo", "-ne", "hello"]), &POSITIONAL_POLICY));
    }

    // A cluster containing an unknown byte is accepted as a positional.
    #[test]
    fn positional_style_combo_unknown_byte() {
        assert!(check(&toks(&["echo", "-nx", "hello"]), &POSITIONAL_POLICY));
    }

    #[test]
    fn positional_style_unknown_eq() {
        assert!(check(&toks(&["echo", "--foo=bar"]), &POSITIONAL_POLICY));
    }

    // Unknown flags accepted as positionals count toward the limit: two
    // positionals pass a limit of two, three do not.
    #[test]
    fn positional_style_with_max_positional() {
        let policy = FlagPolicy {
            standalone: WordSet::new(&["-n"]),
            standalone_short: b"n",
            valued: WordSet::new(&[]),
            valued_short: b"",
            bare: true,
            max_positional: Some(2),
            flag_style: FlagStyle::Positional,
        };
        assert!(check(&toks(&["echo", "--unknown", "hello"]), &policy));
        assert!(!check(&toks(&["echo", "--a", "--b", "--c"]), &policy));
    }
}