1use std::collections::HashMap; use std::sync::Arc;
2use std::fmt;
3
4#[derive(Clone)]
5pub struct Arg {
6 pub name: String,
7 pub short: Option<char>,
8 pub long: Option<String>,
9 pub takes_value: bool,
10 pub required: bool,
11 pub default: Option<String>,
12 pub validator: Option<Arc<dyn Fn(&str) -> bool + Send + Sync>>,
13}
14
15impl fmt::Debug for Arg {
16 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17 f.debug_struct("Arg")
18 .field("name", &self.name)
19 .field("short", &self.short)
20 .field("long", &self.long)
21 .field("takes_value", &self.takes_value)
22 .field("required", &self.required)
23 .field("default", &self.default)
24 .finish()
25 }
26}
27
/// Outcome of `ArgParser::parse`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ArgMatches {
    /// Values of value-taking arguments (including any defaults the parser
    /// filled in), keyed by argument name.
    pub values: HashMap<String, String>,
    /// Flags seen on the command line, keyed by argument name. The parser only
    /// ever stores `true`; absence means the flag was not given.
    pub flags: HashMap<String, bool>,
    /// Non-option words, in the order they appeared.
    pub positionals: Vec<String>,
}

impl ArgMatches {
    /// Returns the value recorded for argument `name`, if any.
    pub fn value_of(&self, name: &str) -> Option<&str> {
        self.values.get(name).map(String::as_str)
    }

    /// Returns `true` if flag `name` was recorded as set.
    pub fn flag(&self, name: &str) -> bool {
        self.flags.get(name).copied().unwrap_or(false)
    }
}
34
35pub struct ArgParser {
36 args: Vec<Arg>,
37 subcommands: HashMap<String, ArgParser>,
38}
39
40impl ArgParser {
41 pub fn new() -> Self {
42 Self {
43 args: Vec::new(),
44 subcommands: HashMap::new(),
45 }
46 }
47
48 pub fn arg(mut self, name: &str) -> Self {
49 self.args.push(Arg {
50 name: name.to_string(),
51 short: None,
52 long: None,
53 takes_value: false,
54 required: false,
55 default: None,
56 validator: None,
57 });
58 self
59 }
60
61 pub fn short(mut self, name: &str, short: char) -> Self {
62 if let Some(arg) = self.args.iter_mut().find(|a| a.name == name) {
63 arg.short = Some(short);
64 }
65 self
66 }
67
68 pub fn long(mut self, name: &str, long: &str) -> Self {
69 if let Some(arg) = self.args.iter_mut().find(|a| a.name == name) {
70 arg.long = Some(long.to_string());
71 }
72 self
73 }
74
75 pub fn takes_value(mut self, name: &str) -> Self {
76 if let Some(arg) = self.args.iter_mut().find(|a| a.name == name) {
77 arg.takes_value = true;
78 }
79 self
80 }
81
82 pub fn required(mut self, name: &str) -> Self {
83 if let Some(arg) = self.args.iter_mut().find(|a| a.name == name) {
84 arg.required = true;
85 }
86 self
87 }
88
89 pub fn default(mut self, name: &str, default: &str) -> Self {
90 if let Some(arg) = self.args.iter_mut().find(|a| a.name == name) {
91 arg.default = Some(default.to_string());
92 }
93 self
94 }
95
96 pub fn validator<F>(mut self, name: &str, validator: F) -> Self
97 where
98 F: 'static + Fn(&str) -> bool + Send + Sync,
99 {
100 if let Some(arg) = self.args.iter_mut().find(|a| a.name == name) {
101 arg.validator = Some(Arc::new(validator));
102 }
103 self
104 }
105
106 pub fn subcommand(mut self, name: &str, parser: ArgParser) -> Self {
107 self.subcommands.insert(name.to_string(), parser);
108 self
109 }
110
111 pub fn parse(mut self, args: &[String]) -> ArgMatches {
112 let mut values = HashMap::new();
113 let mut flags = HashMap::new();
114 let mut positionals = Vec::new();
115 let mut iter = args.iter().skip(1).peekable();
116
117 while let Some(arg) = iter.next() {
118 if arg.starts_with("--") {
119 let name = &arg[2..];
120 if let Some(a) = self.args.iter().find(|a| a.long.as_deref() == Some(name)) {
121 if a.takes_value {
122 if let Some(value) = iter.next() {
123 if let Some(validator) = &a.validator {
124 if !validator(value) {
125 panic!("Invalid value for argument: {}", name);
126 }
127 }
128 values.insert(a.name.clone(), value.clone());
129 }
130 } else {
131 flags.insert(a.name.clone(), true);
132 }
133 }
134 } else if arg.starts_with('-') {
135 let chars: Vec<char> = arg.chars().skip(1).collect();
136 for &c in &chars {
137 if let Some(a) = self.args.iter().find(|a| a.short == Some(c)) {
138 if a.takes_value {
139 if let Some(value) = iter.next() {
140 if let Some(validator) = &a.validator {
141 if !validator(value) {
142 panic!("Invalid value for argument: -{}", c);
143 }
144 }
145 values.insert(a.name.clone(), value.clone());
146 }
147 } else {
148 flags.insert(a.name.clone(), true);
149 }
150 }
151 }
152 } else if self.subcommands.contains_key(arg) {
153 let sub = self.subcommands.remove(arg).unwrap();
154 return sub.parse(&args[1..]);
155 } else {
156 positionals.push(arg.clone());
157 }
158 }
159
160 for arg in &self.args {
161 if arg.required && !values.contains_key(&arg.name) {
162 if let Some(default) = &arg.default {
163 values.insert(arg.name.clone(), default.clone());
164 } else {
165 panic!("Missing required argument: {}", arg.name);
166 }
167 }
168 }
169
170 ArgMatches {
171 values,
172 flags,
173 positionals,
174 }
175 }
176}