1use crate::parse::{Token, WordSet};
2
/// How [`check_flags`] treats a dash-prefixed token that matches no allowed flag.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq)]
pub enum FlagStyle {
    /// Unrecognised dash-prefixed tokens make the whole check fail.
    Strict,
    /// Unrecognised dash-prefixed tokens are counted as positional arguments instead.
    Positional,
}
8
/// A collection of allowed flags that [`check_flags`] can query.
pub trait FlagSet {
    /// Returns `true` when `token` (a complete flag such as `--color` or `-n`) is in the set.
    fn contains_flag(&self, token: &str) -> bool;

    /// Returns `true` when the set holds a short flag whose letter is `byte`.
    ///
    /// Used to validate each byte of a combined short-flag cluster such as `-rn`.
    fn contains_short(&self, byte: u8) -> bool;
}
13
14impl FlagSet for WordSet {
15 fn contains_flag(&self, token: &str) -> bool {
16 self.contains(token)
17 }
18 fn contains_short(&self, byte: u8) -> bool {
19 self.contains_short(byte)
20 }
21}
22
23impl FlagSet for [String] {
24 fn contains_flag(&self, token: &str) -> bool {
25 self.iter().any(|f| f.as_str() == token)
26 }
27 fn contains_short(&self, byte: u8) -> bool {
28 self.iter().any(|f| f.len() == 2 && f.as_bytes()[1] == byte)
29 }
30}
31
32impl FlagSet for Vec<String> {
33 fn contains_flag(&self, token: &str) -> bool {
34 self.as_slice().contains_flag(token)
35 }
36 fn contains_short(&self, byte: u8) -> bool {
37 self.as_slice().contains_short(byte)
38 }
39}
40
41pub struct FlagPolicy {
42 pub standalone: WordSet,
43 pub valued: WordSet,
44 pub bare: bool,
45 pub max_positional: Option<usize>,
46 pub flag_style: FlagStyle,
47 pub numeric_dash: bool,
48}
49
50impl FlagPolicy {
51 pub fn describe(&self) -> String {
52 use crate::docs::wordset_items;
53 let mut lines = Vec::new();
54 let standalone = wordset_items(&self.standalone);
55 if !standalone.is_empty() {
56 lines.push(format!("- Allowed standalone flags: {standalone}"));
57 }
58 let valued = wordset_items(&self.valued);
59 if !valued.is_empty() {
60 lines.push(format!("- Allowed valued flags: {valued}"));
61 }
62 if self.bare {
63 lines.push("- Bare invocation allowed".to_string());
64 }
65 if self.flag_style == FlagStyle::Positional {
66 lines.push("- Hyphen-prefixed positional arguments accepted".to_string());
67 }
68 if self.numeric_dash {
69 lines.push("- Numeric shorthand accepted (e.g. -20 for -n 20)".to_string());
70 }
71 if lines.is_empty() && !self.bare {
72 return "- Positional arguments only".to_string();
73 }
74 lines.join("\n")
75 }
76
77}
78
79pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
80 check_flags(
81 tokens,
82 &policy.standalone,
83 &policy.valued,
84 policy.bare,
85 policy.max_positional,
86 policy.flag_style,
87 policy.numeric_dash,
88 )
89}
90
91pub fn check_flags<S: FlagSet + ?Sized, V: FlagSet + ?Sized>(
92 tokens: &[Token],
93 standalone: &S,
94 valued: &V,
95 bare: bool,
96 max_positional: Option<usize>,
97 flag_style: FlagStyle,
98 numeric_dash: bool,
99) -> bool {
100 if tokens.len() == 1 {
101 return bare;
102 }
103
104 let mut i = 1;
105 let mut positionals: usize = 0;
106 while i < tokens.len() {
107 let t = &tokens[i];
108
109 if *t == "--" {
110 positionals += tokens.len() - i - 1;
111 break;
112 }
113
114 if !t.starts_with('-') {
115 positionals += 1;
116 i += 1;
117 continue;
118 }
119
120 if numeric_dash && t.len() > 1 && t[1..].bytes().all(|b| b.is_ascii_digit()) {
121 i += 1;
122 continue;
123 }
124
125 if standalone.contains_flag(t) {
126 i += 1;
127 continue;
128 }
129
130 if valued.contains_flag(t) {
131 i += 2;
132 continue;
133 }
134
135 if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
136 if valued.contains_flag(flag) {
137 i += 1;
138 continue;
139 }
140 if flag_style == FlagStyle::Positional {
141 positionals += 1;
142 i += 1;
143 continue;
144 }
145 return false;
146 }
147
148 if t.starts_with("--") {
149 if flag_style == FlagStyle::Positional {
150 positionals += 1;
151 i += 1;
152 continue;
153 }
154 return false;
155 }
156
157 let bytes = t.as_bytes();
158 let mut j = 1;
159 while j < bytes.len() {
160 let b = bytes[j];
161 let is_last = j == bytes.len() - 1;
162 if standalone.contains_short(b) {
163 j += 1;
164 continue;
165 }
166 if valued.contains_short(b) {
167 if is_last {
168 i += 1;
169 }
170 break;
171 }
172 if flag_style == FlagStyle::Positional {
173 positionals += 1;
174 break;
175 }
176 return false;
177 }
178 i += 1;
179 }
180 max_positional.is_none_or(|max| positionals <= max)
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds test tokens from plain words; the first word stands in for the command name.
    fn tokenize(words: &[&str]) -> Vec<Token> {
        let mut tokens = Vec::with_capacity(words.len());
        for word in words {
            tokens.push(Token::from_test(word));
        }
        tokens
    }

    /// Runs `check` on `words` under `policy`.
    fn accepts(policy: &FlagPolicy, words: &[&str]) -> bool {
        check(&tokenize(words), policy)
    }

    // ---------------------------------------------------------------------------------
    // Strict policy, no positional limit
    // ---------------------------------------------------------------------------------

    static TEST_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-H", "-c", "-i", "-l", "-n", "-o", "-r", "-s", "-v", "-w",
        ]),
        valued: WordSet::flags(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        bare: false,
        max_positional: None,
        flag_style: FlagStyle::Strict,
        numeric_dash: false,
    };

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!accepts(&TEST_POLICY, &["grep"]));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let bare_ok = FlagPolicy {
            standalone: WordSet::flags(&[]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: None,
            flag_style: FlagStyle::Strict,
            numeric_dash: false,
        };
        assert!(accepts(&bare_ok, &["uname"]));
    }

    #[test]
    fn standalone_long_flag() {
        assert!(accepts(&TEST_POLICY, &["grep", "--recursive", "pattern", "."]));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(accepts(&TEST_POLICY, &["grep", "-r", "pattern", "."]));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(accepts(&TEST_POLICY, &["grep", "--max-count", "5", "pattern"]));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(accepts(&TEST_POLICY, &["grep", "--max-count=5", "pattern"]));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(accepts(&TEST_POLICY, &["grep", "-m", "5", "pattern"]));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(accepts(&TEST_POLICY, &["grep", "-rn", "pattern", "."]));
    }

    #[test]
    fn combined_short_with_valued_last() {
        assert!(accepts(&TEST_POLICY, &["grep", "-rnm", "5", "pattern"]));
    }

    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        assert!(accepts(&TEST_POLICY, &["grep", "-rmn", "pattern"]));
    }

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!accepts(&TEST_POLICY, &["grep", "--exec", "cmd"]));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!accepts(&TEST_POLICY, &["grep", "-z", "pattern"]));
    }

    #[test]
    fn unknown_combined_short_denied() {
        assert!(!accepts(&TEST_POLICY, &["grep", "-rz", "pattern"]));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!accepts(&TEST_POLICY, &["grep", "--output=file.txt", "pattern"]));
    }

    #[test]
    fn double_dash_stops_checking() {
        assert!(accepts(&TEST_POLICY, &["grep", "--", "--not-a-flag", "file"]));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(accepts(&TEST_POLICY, &["grep", "pattern", "file.txt", "other.txt"]));
    }

    #[test]
    fn mixed_flags_and_positional() {
        let words = ["grep", "-rn", "--color", "--max-count", "10", "pattern", "."];
        assert!(accepts(&TEST_POLICY, &words));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(accepts(&TEST_POLICY, &["grep", "-A", "3", "-B", "3", "pattern"]));
    }

    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(accepts(&TEST_POLICY, &["grep", "pattern", "-"]));
    }

    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(accepts(&TEST_POLICY, &["grep", "--max-count"]));
    }

    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(accepts(&TEST_POLICY, &["grep", "-c", "pattern"]));
    }

    // ---------------------------------------------------------------------------------
    // Strict policy with at most one positional argument
    // ---------------------------------------------------------------------------------

    static LIMITED_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--count", "-c", "-d", "-i", "-u"]),
        valued: WordSet::flags(&["--skip-fields", "-f", "-s"]),
        bare: true,
        max_positional: Some(1),
        flag_style: FlagStyle::Strict,
        numeric_dash: false,
    };

    #[test]
    fn max_positional_within_limit() {
        assert!(accepts(&LIMITED_POLICY, &["uniq", "input.txt"]));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!accepts(&LIMITED_POLICY, &["uniq", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(accepts(&LIMITED_POLICY, &["uniq", "-c", "-f", "3", "input.txt"]));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!accepts(&LIMITED_POLICY, &["uniq", "-c", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_after_double_dash() {
        assert!(!accepts(&LIMITED_POLICY, &["uniq", "--", "input.txt", "output.txt"]));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(accepts(&LIMITED_POLICY, &["uniq"]));
    }

    // ---------------------------------------------------------------------------------
    // Positional style: unknown dash-prefixed tokens become positionals
    // ---------------------------------------------------------------------------------

    static POSITIONAL_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["-E", "-e", "-n"]),
        valued: WordSet::flags(&[]),
        bare: true,
        max_positional: None,
        flag_style: FlagStyle::Positional,
        numeric_dash: false,
    };

    #[test]
    fn positional_style_unknown_long() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "--unknown", "hello"]));
    }

    #[test]
    fn positional_style_unknown_short() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "-x", "hello"]));
    }

    #[test]
    fn positional_style_dashes() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "---"]));
    }

    #[test]
    fn positional_style_known_flags_still_work() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "-n", "hello"]));
    }

    #[test]
    fn positional_style_combo_known() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "-ne", "hello"]));
    }

    #[test]
    fn positional_style_combo_unknown_byte() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "-nx", "hello"]));
    }

    #[test]
    fn positional_style_unknown_eq() {
        assert!(accepts(&POSITIONAL_POLICY, &["echo", "--foo=bar"]));
    }

    #[test]
    fn positional_style_with_max_positional() {
        let capped = FlagPolicy {
            standalone: WordSet::flags(&["-n"]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: Some(2),
            flag_style: FlagStyle::Positional,
            numeric_dash: false,
        };
        assert!(accepts(&capped, &["echo", "--unknown", "hello"]));
        assert!(!accepts(&capped, &["echo", "--a", "--b", "--c"]));
    }

    // ---------------------------------------------------------------------------------
    // Numeric shorthand (`-20`)
    // ---------------------------------------------------------------------------------

    static NUMERIC_DASH_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--help", "--quiet", "--verbose", "--version",
            "-V", "-h", "-q", "-v", "-z",
        ]),
        valued: WordSet::flags(&["--bytes", "--lines", "-c", "-n"]),
        bare: true,
        max_positional: None,
        flag_style: FlagStyle::Strict,
        numeric_dash: true,
    };

    #[test]
    fn numeric_dash_single_digit() {
        assert!(accepts(&NUMERIC_DASH_POLICY, &["head", "-5"]));
    }

    #[test]
    fn numeric_dash_multi_digit() {
        assert!(accepts(&NUMERIC_DASH_POLICY, &["head", "-20"]));
    }

    #[test]
    fn numeric_dash_large_number() {
        assert!(accepts(&NUMERIC_DASH_POLICY, &["head", "-1000"]));
    }

    #[test]
    fn numeric_dash_with_file_arg() {
        assert!(accepts(&NUMERIC_DASH_POLICY, &["head", "-20", "file.txt"]));
    }

    #[test]
    fn numeric_dash_with_other_flags() {
        assert!(accepts(&NUMERIC_DASH_POLICY, &["head", "-q", "-20", "file.txt"]));
    }

    #[test]
    fn numeric_dash_zero() {
        assert!(accepts(&NUMERIC_DASH_POLICY, &["head", "-0"]));
    }

    #[test]
    fn numeric_dash_still_rejects_unknown_flags() {
        assert!(!accepts(&NUMERIC_DASH_POLICY, &["head", "-x"]));
    }

    #[test]
    fn numeric_dash_rejects_mixed_alpha_num() {
        assert!(!accepts(&NUMERIC_DASH_POLICY, &["head", "-20x"]));
    }

    #[test]
    fn numeric_dash_disabled_rejects_multi_digit() {
        assert!(!accepts(&TEST_POLICY, &["grep", "-20", "pattern"]));
    }
}