// rust_arguments/lib.rs

use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::sync::Arc;

/// Specification of one command-line argument registered on an `ArgParser`.
///
/// An argument may be matched by a short form (`-c`), a long form (`--name`), or both.
/// Arguments that do not take a value are recorded as boolean flags.
#[derive(Clone)]
pub struct Arg {
    /// Key under which the parsed value or flag is stored in `ArgMatches`.
    pub name: String,
    /// Single-character form, matched as `-c` (also inside clusters such as `-abc`).
    pub short: Option<char>,
    /// Long form, matched as `--long` (stored without the leading dashes).
    pub long: Option<String>,
    /// When `true`, the argument consumes the following command-line word as its value.
    pub takes_value: bool,
    /// Marks the argument as mandatory; the parser reports it when it is missing.
    pub required: bool,
    /// Fallback value the parser may store for the argument when it is not supplied.
    pub default: Option<String>,
    /// Optional predicate that a supplied value must satisfy.
    pub validator: Option<Arc<dyn Fn(&str) -> bool + Sync + Send>>,
}

15impl fmt::Debug for Arg {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        f.debug_struct("Arg")
18            .field("name", &self.name)
19            .field("short", &self.short)
20            .field("long", &self.long)
21            .field("takes_value", &self.takes_value)
22            .field("required", &self.required)
23            .field("default", &self.default)
24            .finish()
25    }
26}
27
/// Result of parsing a command line with an `ArgParser`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ArgMatches {
    /// Argument values (supplied on the command line or defaulted), keyed by argument name.
    pub values: HashMap<String, String>,
    /// Flags seen on the command line, keyed by argument name (the parser stores `true`).
    pub flags: HashMap<String, bool>,
    /// Words that were not consumed as options or option values, in command-line order.
    pub positionals: Vec<String>,
}

impl ArgMatches {
    /// Returns the value stored for the argument `name`, if any.
    pub fn value_of(&self, name: &str) -> Option<&str> {
        self.values.get(name).map(String::as_str)
    }

    /// Returns `true` if `name` is recorded as a flag set to `true` or has a stored value.
    pub fn is_present(&self, name: &str) -> bool {
        self.flags.get(name).copied().unwrap_or(false) || self.values.contains_key(name)
    }
}

35pub struct ArgParser {
36    args: Vec<Arg>,
37    subcommands: HashMap<String, ArgParser>,
38}
39
40impl ArgParser {
41    pub fn new() -> Self {
42        Self {
43            args: Vec::new(),
44            subcommands: HashMap::new(),
45        }
46    }
47
48    pub fn arg(mut self, name: &str) -> Self {
49        self.args.push(Arg {
50            name: name.to_string(),
51            short: None,
52            long: None,
53            takes_value: false,
54            required: false,
55            default: None,
56            validator: None,
57        });
58        self
59    }
60
61    pub fn short(mut self, name: &str, short: char) -> Self {
62        if let Some(arg) = self.args.iter_mut().find(|a| a.name == name) {
63            arg.short = Some(short);
64        }
65        self
66    }
67
68    pub fn long(mut self, name: &str, long: &str) -> Self {
69        if let Some(arg) = self.args.iter_mut().find(|a| a.name == name) {
70            arg.long = Some(long.to_string());
71        }
72        self
73    }
74
75    pub fn takes_value(mut self, name: &str) -> Self {
76        if let Some(arg) = self.args.iter_mut().find(|a| a.name == name) {
77            arg.takes_value = true;
78        }
79        self
80    }
81
82    pub fn required(mut self, name: &str) -> Self {
83        if let Some(arg) = self.args.iter_mut().find(|a| a.name == name) {
84            arg.required = true;
85        }
86        self
87    }
88
89    pub fn default(mut self, name: &str, default: &str) -> Self {
90        if let Some(arg) = self.args.iter_mut().find(|a| a.name == name) {
91            arg.default = Some(default.to_string());
92        }
93        self
94    }
95
96    pub fn validator<F>(mut self, name: &str, validator: F) -> Self
97    where
98        F: 'static + Fn(&str) -> bool + Send + Sync,
99    {
100        if let Some(arg) = self.args.iter_mut().find(|a| a.name == name) {
101            arg.validator = Some(Arc::new(validator));
102        }
103        self
104    }
105
106    pub fn subcommand(mut self, name: &str, parser: ArgParser) -> Self {
107        self.subcommands.insert(name.to_string(), parser);
108        self
109    }
110
111    pub fn parse(mut self, args: &[String]) -> ArgMatches {
112        let mut values = HashMap::new();
113        let mut flags = HashMap::new();
114        let mut positionals = Vec::new();
115        let mut iter = args.iter().skip(1).peekable();
116
117        while let Some(arg) = iter.next() {
118            if arg.starts_with("--") {
119                let name = &arg[2..];
120                if let Some(a) = self.args.iter().find(|a| a.long.as_deref() == Some(name)) {
121                    if a.takes_value {
122                        if let Some(value) = iter.next() {
123                            if let Some(validator) = &a.validator {
124                                if !validator(value) {
125                                    panic!("Invalid value for argument: {}", name);
126                                }
127                            }
128                            values.insert(a.name.clone(), value.clone());
129                        }
130                    } else {
131                        flags.insert(a.name.clone(), true);
132                    }
133                }
134            } else if arg.starts_with('-') {
135                let chars: Vec<char> = arg.chars().skip(1).collect();
136                for &c in &chars {
137                    if let Some(a) = self.args.iter().find(|a| a.short == Some(c)) {
138                        if a.takes_value {
139                            if let Some(value) = iter.next() {
140                                if let Some(validator) = &a.validator {
141                                    if !validator(value) {
142                                        panic!("Invalid value for argument: -{}", c);
143                                    }
144                                }
145                                values.insert(a.name.clone(), value.clone());
146                            }
147                        } else {
148                            flags.insert(a.name.clone(), true);
149                        }
150                    }
151                }
152            } else if self.subcommands.contains_key(arg) {
153                let sub = self.subcommands.remove(arg).unwrap();
154                return sub.parse(&args[1..]);
155            } else {
156                positionals.push(arg.clone());
157            }
158        }
159
160        for arg in &self.args {
161            if arg.required && !values.contains_key(&arg.name) {
162                if let Some(default) = &arg.default {
163                    values.insert(arg.name.clone(), default.clone());
164                } else {
165                    panic!("Missing required argument: {}", arg.name);
166                }
167            }
168        }
169
170        ArgMatches {
171            values,
172            flags,
173            positionals,
174        }
175    }
176}