Skip to main content

safe_chains/
policy.rs

1use crate::parse::{Token, WordSet};
2
3pub struct FlagPolicy {
4    pub standalone: WordSet,
5    pub standalone_short: &'static [u8],
6    pub valued: WordSet,
7    pub valued_short: &'static [u8],
8    pub bare: bool,
9    pub max_positional: Option<usize>,
10}
11
12impl FlagPolicy {
13    pub fn describe(&self) -> String {
14        use crate::docs::{wordset_items, DocBuilder};
15        let mut builder = DocBuilder::new();
16        let standalone = wordset_items(&self.standalone);
17        if !standalone.is_empty() {
18            builder = builder.section(format!("Allowed standalone flags: {standalone}."));
19        }
20        let valued = wordset_items(&self.valued);
21        if !valued.is_empty() {
22            builder = builder.section(format!("Allowed valued flags: {valued}."));
23        }
24        if self.bare {
25            builder = builder.section("Bare invocation allowed.".to_string());
26        }
27        builder.build()
28    }
29}
30
31pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
32    if tokens.len() == 1 {
33        return policy.bare;
34    }
35
36    let mut i = 1;
37    let mut positionals: usize = 0;
38    while i < tokens.len() {
39        let t = &tokens[i];
40
41        if *t == "--" {
42            positionals += tokens.len() - i - 1;
43            break;
44        }
45
46        if !t.starts_with('-') {
47            positionals += 1;
48            i += 1;
49            continue;
50        }
51
52        if policy.standalone.contains(t) {
53            i += 1;
54            continue;
55        }
56
57        if policy.valued.contains(t) {
58            i += 2;
59            continue;
60        }
61
62        if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
63            if policy.valued.contains(flag) {
64                i += 1;
65                continue;
66            }
67            return false;
68        }
69
70        if t.starts_with("--") {
71            return false;
72        }
73
74        let bytes = t.as_bytes();
75        let mut j = 1;
76        while j < bytes.len() {
77            let b = bytes[j];
78            let is_last = j == bytes.len() - 1;
79            if policy.standalone_short.contains(&b) {
80                j += 1;
81                continue;
82            }
83            if policy.valued_short.contains(&b) {
84                if is_last {
85                    i += 1;
86                }
87                break;
88            }
89            return false;
90        }
91        i += 1;
92    }
93    policy.max_positional.is_none_or(|max| positionals <= max)
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Grep-shaped policy: long and short standalone/valued flags, no bare
    /// invocation, unlimited positionals.
    static GREP_LIKE: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-c", "-r",
        ]),
        standalone_short: b"cHilnorsvw",
        valued: WordSet::new(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        valued_short: b"ABm",
        bare: false,
        max_positional: None,
    };

    /// Uniq-shaped policy: bare invocation allowed, at most one positional.
    static UNIQ_LIKE: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&["--count", "-c", "-d", "-i", "-u"]),
        standalone_short: b"cdiu",
        valued: WordSet::new(&["--skip-fields", "-f", "-s"]),
        valued_short: b"fs",
        bare: true,
        max_positional: Some(1),
    };

    /// Empty flag sets with bare invocation allowed.
    static BARE_ONLY: FlagPolicy = FlagPolicy {
        standalone: WordSet::new(&[]),
        standalone_short: b"",
        valued: WordSet::new(&[]),
        valued_short: b"",
        bare: true,
        max_positional: None,
    };

    /// Tokenizes `words` and runs `check` against `policy`.
    fn allowed(policy: &FlagPolicy, words: &[&str]) -> bool {
        let tokens: Vec<Token> = words.iter().map(|w| Token::from_test(w)).collect();
        check(&tokens, policy)
    }

    // --- bare invocation -------------------------------------------------

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!allowed(&GREP_LIKE, &["grep"]));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        assert!(allowed(&BARE_ONLY, &["uname"]));
    }

    // --- accepted flag shapes --------------------------------------------

    #[test]
    fn standalone_long_flag() {
        assert!(allowed(&GREP_LIKE, &["grep", "--recursive", "pattern", "."]));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(allowed(&GREP_LIKE, &["grep", "-r", "pattern", "."]));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(allowed(&GREP_LIKE, &["grep", "--max-count", "5", "pattern"]));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(allowed(&GREP_LIKE, &["grep", "--max-count=5", "pattern"]));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(allowed(&GREP_LIKE, &["grep", "-m", "5", "pattern"]));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(allowed(&GREP_LIKE, &["grep", "-rn", "pattern", "."]));
    }

    #[test]
    fn combined_short_with_valued_last() {
        assert!(allowed(&GREP_LIKE, &["grep", "-rnm", "5", "pattern"]));
    }

    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        // `n` after the valued `m` is m's value, not another flag.
        assert!(allowed(&GREP_LIKE, &["grep", "-rmn", "pattern"]));
    }

    // --- rejected flags --------------------------------------------------

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!allowed(&GREP_LIKE, &["grep", "--exec", "cmd"]));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!allowed(&GREP_LIKE, &["grep", "-z", "pattern"]));
    }

    #[test]
    fn unknown_combined_short_denied() {
        assert!(!allowed(&GREP_LIKE, &["grep", "-rz", "pattern"]));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!allowed(&GREP_LIKE, &["grep", "--output=file.txt", "pattern"]));
    }

    // --- positionals and edge cases --------------------------------------

    #[test]
    fn double_dash_stops_checking() {
        assert!(allowed(&GREP_LIKE, &["grep", "--", "--not-a-flag", "file"]));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(allowed(&GREP_LIKE, &["grep", "pattern", "file.txt", "other.txt"]));
    }

    #[test]
    fn mixed_flags_and_positional() {
        assert!(allowed(
            &GREP_LIKE,
            &["grep", "-rn", "--color", "--max-count", "10", "pattern", "."],
        ));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(allowed(&GREP_LIKE, &["grep", "-A", "3", "-B", "3", "pattern"]));
    }

    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(allowed(&GREP_LIKE, &["grep", "pattern", "-"]));
    }

    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(allowed(&GREP_LIKE, &["grep", "--max-count"]));
    }

    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(allowed(&GREP_LIKE, &["grep", "-c", "pattern"]));
    }

    // --- max_positional --------------------------------------------------

    #[test]
    fn max_positional_within_limit() {
        assert!(allowed(&UNIQ_LIKE, &["uniq", "input.txt"]));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!allowed(&UNIQ_LIKE, &["uniq", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(allowed(&UNIQ_LIKE, &["uniq", "-c", "-f", "3", "input.txt"]));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!allowed(&UNIQ_LIKE, &["uniq", "-c", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_after_double_dash() {
        assert!(!allowed(&UNIQ_LIKE, &["uniq", "--", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(allowed(&UNIQ_LIKE, &["uniq"]));
    }
}