1use crate::parse::{Token, WordSet};
2
/// Selects which kinds of unrecognised hyphen-prefixed tokens a flag check lets through.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum UnknownTolerance {
    /// Tolerates nothing: both `allows_short` and `allows_long` return `false`.
    #[default]
    Strict,
    /// Only `allows_short` returns `true`.
    Short,
    /// Only `allows_long` returns `true`.
    Long,
    /// Both `allows_short` and `allows_long` return `true`.
    Both,
}

impl UnknownTolerance {
    /// Returns `true` for [`UnknownTolerance::Short`] and [`UnknownTolerance::Both`].
    pub const fn allows_short(self) -> bool {
        match self {
            Self::Short | Self::Both => true,
            Self::Strict | Self::Long => false,
        }
    }

    /// Returns `true` for [`UnknownTolerance::Long`] and [`UnknownTolerance::Both`].
    pub const fn allows_long(self) -> bool {
        match self {
            Self::Long | Self::Both => true,
            Self::Strict | Self::Short => false,
        }
    }
}
34
35#[derive(Clone, Copy, Debug, Default)]
39pub struct FlagTolerance {
40 pub unknown: UnknownTolerance,
41 pub numeric_dash: bool,
42}
43
44impl FlagTolerance {
45 pub const fn strict() -> Self {
48 Self { unknown: UnknownTolerance::Strict, numeric_dash: false }
49 }
50}
51
/// A queryable collection of flag spellings used by the flag checker.
pub trait FlagSet {
    /// Returns `true` when the complete `token` is a member of the set.
    fn contains_flag(&self, token: &str) -> bool;
    /// Returns `true` when the set contains the short flag identified by `byte`,
    /// i.e. a byte that follows the leading `-` of a token.
    fn contains_short(&self, byte: u8) -> bool;
}
56
/// [`FlagSet`] backed by a [`WordSet`]; both queries are forwarded to the set itself.
impl FlagSet for WordSet {
    /// Forwards to `self.contains(token)`.
    fn contains_flag(&self, token: &str) -> bool {
        self.contains(token)
    }
    /// Forwards to `self.contains_short(byte)`.
    // NOTE(review): presumably this resolves to an inherent `WordSet::contains_short`
    // (inherent methods take priority over trait methods). If no such inherent method
    // exists, this call would recurse into itself — confirm against `WordSet`.
    fn contains_short(&self, byte: u8) -> bool {
        self.contains_short(byte)
    }
}
65
66impl FlagSet for [String] {
67 fn contains_flag(&self, token: &str) -> bool {
68 self.iter().any(|f| f.as_str() == token)
69 }
70 fn contains_short(&self, byte: u8) -> bool {
71 self.iter().any(|f| f.len() == 2 && f.as_bytes()[1] == byte)
72 }
73}
74
75impl FlagSet for Vec<String> {
76 fn contains_flag(&self, token: &str) -> bool {
77 self.as_slice().contains_flag(token)
78 }
79 fn contains_short(&self, byte: u8) -> bool {
80 self.as_slice().contains_short(byte)
81 }
82}
83
/// The flag allow-list and limits for one command, consumed by `check` and `describe`.
pub struct FlagPolicy {
    /// Flags accepted on their own (no value).
    pub standalone: WordSet,
    /// Flags that take a value, either as the following token or as `flag=value`.
    pub valued: WordSet,
    /// Whether invoking the command with no arguments at all is allowed.
    pub bare: bool,
    /// Upper bound on the number of positional arguments; `None` means no limit.
    pub max_positional: Option<usize>,
    /// Leniency towards unknown flags and numeric shorthand.
    pub tolerance: FlagTolerance,
}
91
92impl FlagPolicy {
93 pub fn describe(&self) -> String {
94 use crate::docs::wordset_items;
95 let mut lines = Vec::new();
96 let standalone = wordset_items(&self.standalone);
97 if !standalone.is_empty() {
98 lines.push(format!("- Allowed standalone flags: {standalone}"));
99 }
100 let valued = wordset_items(&self.valued);
101 if !valued.is_empty() {
102 lines.push(format!("- Allowed valued flags: {valued}"));
103 }
104 if self.bare {
105 lines.push("- Bare invocation allowed".to_string());
106 }
107 if self.tolerance.unknown != UnknownTolerance::Strict {
108 lines.push("- Hyphen-prefixed positional arguments accepted".to_string());
109 }
110 if self.tolerance.numeric_dash {
111 lines.push("- Numeric shorthand accepted (e.g. -20 for -n 20)".to_string());
112 }
113 if lines.is_empty() && !self.bare {
114 return "- Positional arguments only".to_string();
115 }
116 lines.join("\n")
117 }
118
119}
120
121pub fn check(tokens: &[Token], policy: &FlagPolicy) -> bool {
122 check_flags(
123 tokens,
124 &policy.standalone,
125 &policy.valued,
126 policy.bare,
127 policy.max_positional,
128 policy.tolerance,
129 )
130}
131
132pub fn check_flags<S: FlagSet + ?Sized, V: FlagSet + ?Sized>(
133 tokens: &[Token],
134 standalone: &S,
135 valued: &V,
136 bare: bool,
137 max_positional: Option<usize>,
138 tolerance: FlagTolerance,
139) -> bool {
140 if tokens.len() == 1 {
141 return bare;
142 }
143
144 let mut i = 1;
145 let mut positionals: usize = 0;
146 while i < tokens.len() {
147 let t = &tokens[i];
148
149 if *t == "--" {
150 positionals += tokens.len() - i - 1;
151 break;
152 }
153
154 if !t.starts_with('-') {
155 positionals += 1;
156 i += 1;
157 continue;
158 }
159
160 if tolerance.numeric_dash && t.len() > 1 && t[1..].bytes().all(|b| b.is_ascii_digit()) {
161 i += 1;
162 continue;
163 }
164
165 if standalone.contains_flag(t) {
166 i += 1;
167 continue;
168 }
169
170 if valued.contains_flag(t) {
171 i += 2;
172 continue;
173 }
174
175 if let Some(flag) = t.as_str().split_once('=').map(|(f, _)| f) {
176 if valued.contains_flag(flag) {
177 i += 1;
178 continue;
179 }
180 if tolerance.unknown.allows_long() {
182 positionals += 1;
183 i += 1;
184 continue;
185 }
186 return false;
187 }
188
189 if t.starts_with("--") {
190 if tolerance.unknown.allows_long() {
191 positionals += 1;
192 i += 1;
193 continue;
194 }
195 return false;
196 }
197
198 let bytes = t.as_bytes();
199 let mut j = 1;
200 while j < bytes.len() {
201 let b = bytes[j];
202 let is_last = j == bytes.len() - 1;
203 if standalone.contains_short(b) {
204 j += 1;
205 continue;
206 }
207 if valued.contains_short(b) {
208 if is_last {
209 i += 1;
210 }
211 break;
212 }
213 if tolerance.unknown.allows_short() {
214 positionals += 1;
215 break;
216 }
217 return false;
218 }
219 i += 1;
220 }
221 max_positional.is_none_or(|max| positionals <= max)
222}
223
// Unit tests for `check` / `check_flags`. Each group below is driven by one
// static `FlagPolicy` fixture; `toks` builds the token slice, whose first
// element is always the command name.
#[cfg(test)]
mod tests {
    use super::*;

    // grep-like fixture: strict tolerance, no positional limit, bare invocation denied.
    static TEST_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--color", "--count", "--help", "--recursive", "--version",
            "-H", "-c", "-i", "-l", "-n", "-o", "-r", "-s", "-v", "-w",
        ]),
        valued: WordSet::flags(&[
            "--after-context", "--before-context", "--max-count",
            "-A", "-B", "-m",
        ]),
        bare: false,
        max_positional: None,
        tolerance: FlagTolerance::strict(),
    };

    // Builds a token list from plain words via the test-only `Token::from_test`.
    fn toks(words: &[&str]) -> Vec<Token> {
        words.iter().map(|s| Token::from_test(s)).collect()
    }

    #[test]
    fn bare_denied_when_bare_false() {
        assert!(!check(&toks(&["grep"]), &TEST_POLICY));
    }

    #[test]
    fn bare_allowed_when_bare_true() {
        let policy = FlagPolicy {
            standalone: WordSet::flags(&[]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: None,
            tolerance: FlagTolerance::strict(),
        };
        assert!(check(&toks(&["uname"]), &policy));
    }

    #[test]
    fn standalone_long_flag() {
        assert!(check(&toks(&["grep", "--recursive", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn standalone_short_flag() {
        assert!(check(&toks(&["grep", "-r", "pattern", "."]), &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_space() {
        assert!(check(&toks(&["grep", "--max-count", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn valued_long_flag_eq() {
        assert!(check(&toks(&["grep", "--max-count=5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn valued_short_flag_space() {
        assert!(check(&toks(&["grep", "-m", "5", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn combined_standalone_short() {
        assert!(check(&toks(&["grep", "-rn", "pattern", "."]), &TEST_POLICY));
    }

    // `-rnm`: the valued letter `m` is last, so the next token ("5") is its value.
    #[test]
    fn combined_short_with_valued_last() {
        assert!(check(&toks(&["grep", "-rnm", "5", "pattern"]), &TEST_POLICY));
    }

    // `-rmn`: the valued letter `m` is not last, so the rest of the cluster ("n") is its value.
    #[test]
    fn combined_short_valued_mid_consumes_rest() {
        assert!(check(&toks(&["grep", "-rmn", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_long_flag_denied() {
        assert!(!check(&toks(&["grep", "--exec", "cmd"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_short_flag_denied() {
        assert!(!check(&toks(&["grep", "-z", "pattern"]), &TEST_POLICY));
    }

    // One unknown letter anywhere in a cluster rejects the whole token.
    #[test]
    fn unknown_combined_short_denied() {
        assert!(!check(&toks(&["grep", "-rz", "pattern"]), &TEST_POLICY));
    }

    #[test]
    fn unknown_long_eq_denied() {
        assert!(!check(&toks(&["grep", "--output=file.txt", "pattern"]), &TEST_POLICY));
    }

    // After `--`, flag-looking tokens are treated as positionals.
    #[test]
    fn double_dash_stops_checking() {
        assert!(check(&toks(&["grep", "--", "--not-a-flag", "file"]), &TEST_POLICY));
    }

    #[test]
    fn positional_args_allowed() {
        assert!(check(&toks(&["grep", "pattern", "file.txt", "other.txt"]), &TEST_POLICY));
    }

    #[test]
    fn mixed_flags_and_positional() {
        assert!(check(
            &toks(&["grep", "-rn", "--color", "--max-count", "10", "pattern", "."]),
            &TEST_POLICY,
        ));
    }

    #[test]
    fn valued_short_in_explicit_form() {
        assert!(check(&toks(&["grep", "-A", "3", "-B", "3", "pattern"]), &TEST_POLICY));
    }

    // A lone `-` is accepted even under strict tolerance.
    #[test]
    fn bare_dash_allowed_as_stdin() {
        assert!(check(&toks(&["grep", "pattern", "-"]), &TEST_POLICY));
    }

    // A valued flag with no following token is still accepted.
    #[test]
    fn valued_flag_at_end_without_value() {
        assert!(check(&toks(&["grep", "--max-count"]), &TEST_POLICY));
    }

    #[test]
    fn single_short_in_wordset_and_byte_array() {
        assert!(check(&toks(&["grep", "-c", "pattern"]), &TEST_POLICY));
    }

    // uniq-like fixture: at most one positional argument, bare invocation allowed.
    static LIMITED_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--count", "-c", "-d", "-i", "-u"]),
        valued: WordSet::flags(&["--skip-fields", "-f", "-s"]),
        bare: true,
        max_positional: Some(1),
        tolerance: FlagTolerance::strict(),
    };

    #[test]
    fn max_positional_within_limit() {
        assert!(check(&toks(&["uniq", "input.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_exceeded() {
        assert!(!check(&toks(&["uniq", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    // The value of `-f` ("3") is not counted as a positional.
    #[test]
    fn max_positional_with_flags_within_limit() {
        assert!(check(&toks(&["uniq", "-c", "-f", "3", "input.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_with_flags_exceeded() {
        assert!(!check(&toks(&["uniq", "-c", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    // Tokens after `--` still count towards the positional limit.
    #[test]
    fn max_positional_after_double_dash() {
        assert!(!check(&toks(&["uniq", "--", "input.txt", "output.txt"]), &LIMITED_POLICY));
    }

    #[test]
    fn max_positional_bare_allowed() {
        assert!(check(&toks(&["uniq"]), &LIMITED_POLICY));
    }

    // echo-like fixture: unknown short and long flags are both tolerated.
    static BOTH_TOLERANCES_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["-E", "-e", "-n"]),
        valued: WordSet::flags(&[]),
        bare: true,
        max_positional: None,
        tolerance: FlagTolerance { unknown: UnknownTolerance::Both, numeric_dash: false },
    };

    #[test]
    fn both_tolerances_accept_unknown_long() {
        assert!(check(&toks(&["echo", "--unknown", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_accept_unknown_short() {
        assert!(check(&toks(&["echo", "-x", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_accept_triple_dash() {
        assert!(check(&toks(&["echo", "---"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_known_flags_still_work() {
        assert!(check(&toks(&["echo", "-n", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_combo_known_short() {
        assert!(check(&toks(&["echo", "-ne", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_combo_unknown_short_byte() {
        assert!(check(&toks(&["echo", "-nx", "hello"]), &BOTH_TOLERANCES_POLICY));
    }

    #[test]
    fn both_tolerances_unknown_eq_form() {
        assert!(check(&toks(&["echo", "--foo=bar"]), &BOTH_TOLERANCES_POLICY));
    }

    // Fixture tolerating unknown single-dash tokens only; unknown `--` tokens stay denied.
    static SHORT_ONLY_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--help"]),
        valued: WordSet::flags(&[]),
        bare: false,
        max_positional: None,
        tolerance: FlagTolerance { unknown: UnknownTolerance::Short, numeric_dash: false },
    };

    #[test]
    fn short_only_accepts_unknown_dash_letter() {
        assert!(check(&toks(&["sample", "-mayDie"]), &SHORT_ONLY_POLICY));
    }

    #[test]
    fn short_only_accepts_single_dash_long_word() {
        assert!(check(&toks(&["pdftotext", "-layout"]), &SHORT_ONLY_POLICY));
    }

    #[test]
    fn short_only_denies_unknown_double_dash() {
        assert!(!check(&toks(&["sample", "--evil-flag"]), &SHORT_ONLY_POLICY));
    }

    #[test]
    fn short_only_denies_unknown_eq_form() {
        assert!(!check(&toks(&["sample", "--evil=value"]), &SHORT_ONLY_POLICY));
    }

    #[test]
    fn short_only_known_long_flag_still_works() {
        assert!(check(&toks(&["sample", "--help"]), &SHORT_ONLY_POLICY));
    }

    // Fixture tolerating unknown `--` tokens (including `=` forms) only; unknown short flags stay denied.
    static LONG_ONLY_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--help"]),
        valued: WordSet::flags(&[]),
        bare: false,
        max_positional: None,
        tolerance: FlagTolerance { unknown: UnknownTolerance::Long, numeric_dash: false },
    };

    #[test]
    fn long_only_accepts_unknown_double_dash() {
        assert!(check(&toks(&["aws", "--some-aws-flag"]), &LONG_ONLY_POLICY));
    }

    #[test]
    fn long_only_accepts_unknown_eq_form() {
        assert!(check(
            &toks(&["aws", "--filter=Name=tag,Values=foo"]),
            &LONG_ONLY_POLICY,
        ));
    }

    #[test]
    fn long_only_denies_unknown_short_dash() {
        assert!(!check(&toks(&["aws", "-x"]), &LONG_ONLY_POLICY));
    }

    // Minimal strict fixture: only `--help` is known and nothing unknown is tolerated.
    static STRICT_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&["--help"]),
        valued: WordSet::flags(&[]),
        bare: false,
        max_positional: None,
        tolerance: FlagTolerance::strict(),
    };

    #[test]
    fn strict_denies_unknown_short() {
        assert!(!check(&toks(&["foo", "-evil"]), &STRICT_POLICY));
    }

    #[test]
    fn strict_denies_unknown_long() {
        assert!(!check(&toks(&["foo", "--evil"]), &STRICT_POLICY));
    }

    #[test]
    fn strict_known_flag_passes() {
        assert!(check(&toks(&["foo", "--help"]), &STRICT_POLICY));
    }

    // Tolerated unknown flags are counted as positionals, so they are subject to the limit.
    #[test]
    fn both_tolerances_with_max_positional() {
        let policy = FlagPolicy {
            standalone: WordSet::flags(&["-n"]),
            valued: WordSet::flags(&[]),
            bare: true,
            max_positional: Some(2),
            tolerance: FlagTolerance { unknown: UnknownTolerance::Both, numeric_dash: false },
        };
        assert!(check(&toks(&["echo", "--unknown", "hello"]), &policy));
        assert!(!check(&toks(&["echo", "--a", "--b", "--c"]), &policy));
    }

    // head-like fixture: strict about unknown flags but accepting `-<digits>` shorthand.
    static NUMERIC_DASH_POLICY: FlagPolicy = FlagPolicy {
        standalone: WordSet::flags(&[
            "--help", "--quiet", "--verbose", "--version",
            "-V", "-h", "-q", "-v", "-z",
        ]),
        valued: WordSet::flags(&["--bytes", "--lines", "-c", "-n"]),
        bare: true,
        max_positional: None,
        tolerance: FlagTolerance { numeric_dash: true, ..FlagTolerance::strict() },
    };

    #[test]
    fn numeric_dash_single_digit() {
        assert!(check(&toks(&["head", "-5"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_multi_digit() {
        assert!(check(&toks(&["head", "-20"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_large_number() {
        assert!(check(&toks(&["head", "-1000"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_with_file_arg() {
        assert!(check(&toks(&["head", "-20", "file.txt"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_with_other_flags() {
        assert!(check(&toks(&["head", "-q", "-20", "file.txt"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_zero() {
        assert!(check(&toks(&["head", "-0"]), &NUMERIC_DASH_POLICY));
    }

    #[test]
    fn numeric_dash_still_rejects_unknown_flags() {
        assert!(!check(&toks(&["head", "-x"]), &NUMERIC_DASH_POLICY));
    }

    // Digits followed by a letter are not numeric shorthand and fall back to normal flag checks.
    #[test]
    fn numeric_dash_rejects_mixed_alpha_num() {
        assert!(!check(&toks(&["head", "-20x"]), &NUMERIC_DASH_POLICY));
    }

    // Without `numeric_dash`, `-20` is treated as an unknown short-flag cluster.
    #[test]
    fn numeric_dash_disabled_rejects_multi_digit() {
        assert!(!check(&toks(&["grep", "-20", "pattern"]), &TEST_POLICY));
    }
}