// zyn_core/meta/arg.rs

use proc_macro2::Ident;

use quote::ToTokens;
use quote::TokenStreamExt;

use syn::Expr;
use syn::Lit;
use syn::Token;
use syn::ext::IdentExt;
use syn::parse::Parse;
use syn::parse::ParseStream;

use super::Args;

14#[derive(Clone)]
15pub enum Arg {
16    Flag(Ident),
17    Expr(Ident, Expr),
18    List(Ident, Args),
19    Lit(Lit),
20}
21
22impl Arg {
23    pub fn name(&self) -> Option<&Ident> {
24        match self {
25            Self::Flag(name) => Some(name),
26            Self::Expr(name, _) => Some(name),
27            Self::List(name, _) => Some(name),
28            Self::Lit(_) => None,
29        }
30    }
31
32    pub fn is_flag(&self) -> bool {
33        matches!(self, Self::Flag(_))
34    }
35
36    pub fn is_expr(&self) -> bool {
37        matches!(self, Self::Expr(_, _))
38    }
39
40    pub fn is_list(&self) -> bool {
41        matches!(self, Self::List(_, _))
42    }
43
44    pub fn is_lit(&self) -> bool {
45        matches!(self, Self::Lit(_))
46    }
47
48    pub fn as_expr(&self) -> &Expr {
49        match self {
50            Self::Expr(_, expr) => expr,
51            _ => panic!("called `Arg::as_expr()` on a non-Expr variant"),
52        }
53    }
54
55    pub fn as_args(&self) -> &Args {
56        match self {
57            Self::List(_, args) => args,
58            _ => panic!("called `Arg::as_args()` on a non-List variant"),
59        }
60    }
61
62    pub fn as_lit(&self) -> &Lit {
63        match self {
64            Self::Lit(lit) => lit,
65            _ => panic!("called `Arg::as_lit()` on a non-Lit variant"),
66        }
67    }
68
69    pub fn as_flag(&self) -> &Ident {
70        match self {
71            Self::Flag(i) => i,
72            _ => panic!("called `Arg::as_flag()` on a non-Flag variant"),
73        }
74    }
75
76    pub fn as_str(&self) -> String {
77        match self {
78            Self::Lit(Lit::Str(s)) => s.value(),
79            Self::Expr(
80                _,
81                syn::Expr::Lit(syn::ExprLit {
82                    lit: Lit::Str(s), ..
83                }),
84            ) => s.value(),
85            _ => panic!("called `Arg::as_str()` on a non-string variant"),
86        }
87    }
88
89    pub fn as_int<T: std::str::FromStr>(&self) -> T
90    where
91        T::Err: std::fmt::Display,
92    {
93        match self {
94            Self::Lit(Lit::Int(i)) => i.base10_parse().expect("invalid integer literal"),
95            Self::Expr(
96                _,
97                syn::Expr::Lit(syn::ExprLit {
98                    lit: Lit::Int(i), ..
99                }),
100            ) => i.base10_parse().expect("invalid integer literal"),
101            _ => panic!("called `Arg::as_int()` on a non-integer variant"),
102        }
103    }
104
105    pub fn as_float<T: std::str::FromStr>(&self) -> T
106    where
107        T::Err: std::fmt::Display,
108    {
109        match self {
110            Self::Lit(Lit::Float(f)) => f.base10_parse().expect("invalid float literal"),
111            Self::Expr(
112                _,
113                syn::Expr::Lit(syn::ExprLit {
114                    lit: Lit::Float(f), ..
115                }),
116            ) => f.base10_parse().expect("invalid float literal"),
117            _ => panic!("called `Arg::as_float()` on a non-float variant"),
118        }
119    }
120
121    pub fn as_char(&self) -> char {
122        match self {
123            Self::Lit(Lit::Char(c)) => c.value(),
124            Self::Expr(
125                _,
126                syn::Expr::Lit(syn::ExprLit {
127                    lit: Lit::Char(c), ..
128                }),
129            ) => c.value(),
130            _ => panic!("called `Arg::as_char()` on a non-char variant"),
131        }
132    }
133
134    pub fn as_expr_lit(&self) -> Option<&Lit> {
135        match self {
136            Self::Lit(lit) => Some(lit),
137            Self::Expr(_, Expr::Lit(syn::ExprLit { lit, .. })) => Some(lit),
138            _ => None,
139        }
140    }
141
142    pub fn as_bool(&self) -> bool {
143        match self {
144            Self::Lit(Lit::Bool(b)) => b.value,
145            Self::Expr(
146                _,
147                syn::Expr::Lit(syn::ExprLit {
148                    lit: Lit::Bool(b), ..
149                }),
150            ) => b.value,
151            _ => panic!("called `Arg::as_bool()` on a non-bool variant"),
152        }
153    }
154}
155
156impl Parse for Arg {
157    fn parse(input: ParseStream) -> syn::Result<Self> {
158        if input.peek(Lit) || input.peek(syn::LitStr) {
159            return Ok(Self::Lit(input.parse()?));
160        }
161
162        let name: Ident = input.parse()?;
163
164        if input.peek(Token![=]) {
165            input.parse::<Token![=]>()?;
166            let expr: Expr = input.parse()?;
167            Ok(Self::Expr(name, expr))
168        } else if input.peek(syn::token::Paren) {
169            let content;
170            syn::parenthesized!(content in input);
171            let args: Args = content.parse()?;
172            Ok(Self::List(name, args))
173        } else {
174            Ok(Self::Flag(name))
175        }
176    }
177}
178
179impl ToTokens for Arg {
180    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
181        match self {
182            Self::Flag(name) => name.to_tokens(tokens),
183            Self::Expr(name, expr) => {
184                name.to_tokens(tokens);
185                tokens.append(proc_macro2::Punct::new('=', proc_macro2::Spacing::Alone));
186                expr.to_tokens(tokens);
187            }
188            Self::List(name, args) => {
189                name.to_tokens(tokens);
190                let inner = args.to_token_stream();
191                tokens.append(proc_macro2::Group::new(
192                    proc_macro2::Delimiter::Parenthesis,
193                    inner,
194                ));
195            }
196            Self::Lit(lit) => lit.to_tokens(tokens),
197        }
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    // Parsing of each surface form into the matching `Arg` variant.
    mod parse {
        use super::*;

        #[test]
        fn flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            assert!(arg.is_flag());
            assert_eq!(arg.name().unwrap(), "skip");
        }

        #[test]
        fn expr() {
            let arg: Arg = syn::parse_str("rename = \"foo\"").unwrap();
            assert!(arg.is_expr());
            assert_eq!(arg.name().unwrap(), "rename");
        }

        #[test]
        fn expr_with_non_literal_value() {
            let arg: Arg = syn::parse_str("x = a + b").unwrap();
            assert!(arg.is_expr());
            assert!(arg.as_expr_lit().is_none());
        }

        #[test]
        fn list() {
            let arg: Arg = syn::parse_str("serde(rename_all = \"camelCase\")").unwrap();
            assert!(arg.is_list());
            assert_eq!(arg.name().unwrap(), "serde");
            assert!(arg.as_args().has("rename_all"));
        }

        #[test]
        fn lit_str() {
            let arg: Arg = syn::parse_str("\"hello\"").unwrap();
            assert!(arg.is_lit());
            assert!(arg.name().is_none());
        }

        #[test]
        fn lit_int() {
            let arg: Arg = syn::parse_str("42").unwrap();
            assert!(arg.is_lit());
        }

        #[test]
        fn lit_bool() {
            let arg: Arg = syn::parse_str("true").unwrap();
            assert!(arg.is_lit());
            assert!(arg.name().is_none());
        }

        #[test]
        fn empty_input_is_error() {
            assert!(syn::parse_str::<Arg>("").is_err());
        }
    }

    // Round-tripping back to tokens.
    mod to_tokens {
        use super::*;

        #[test]
        fn flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "skip");
        }

        #[test]
        fn expr() {
            let arg: Arg = syn::parse_str("rename = \"foo\"").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "rename = \"foo\"");
        }

        #[test]
        fn list() {
            let arg: Arg = syn::parse_str("serde(skip)").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "serde (skip)");
        }

        #[test]
        fn lit() {
            let arg: Arg = syn::parse_str("\"hello\"").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "\"hello\"");
        }

        #[test]
        fn lit_int() {
            let arg: Arg = syn::parse_str("42").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "42");
        }
    }

    // Variant accessors and typed literal accessors, including their panics.
    mod accessors {
        use super::*;

        #[test]
        fn as_expr_returns_value_expression() {
            let arg: Arg = syn::parse_str("x = a + b").unwrap();
            assert_eq!(arg.as_expr().to_token_stream().to_string(), "a + b");
        }

        #[test]
        #[should_panic(expected = "non-Expr")]
        fn as_expr_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_expr();
        }

        #[test]
        #[should_panic(expected = "non-List")]
        fn as_args_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_args();
        }

        #[test]
        #[should_panic(expected = "non-Lit")]
        fn as_lit_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_lit();
        }

        #[test]
        fn as_flag_returns_ident() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            assert_eq!(arg.as_flag().to_string(), "skip");
        }

        #[test]
        #[should_panic(expected = "non-Flag")]
        fn as_flag_panics_on_expr() {
            let arg: Arg = syn::parse_str("x = 1").unwrap();
            arg.as_flag();
        }

        #[test]
        fn as_expr_lit_looks_through_expr() {
            let arg: Arg = syn::parse_str("count = 7").unwrap();
            assert!(matches!(arg.as_expr_lit(), Some(Lit::Int(_))));
        }

        #[test]
        fn as_expr_lit_is_none_for_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            assert!(arg.as_expr_lit().is_none());
        }

        #[test]
        fn as_str_from_lit() {
            let arg: Arg = syn::parse_str("\"hello\"").unwrap();
            assert_eq!(arg.as_str(), "hello");
        }

        #[test]
        fn as_str_from_expr() {
            let arg: Arg = syn::parse_str("rename = \"foo\"").unwrap();
            assert_eq!(arg.as_str(), "foo");
        }

        #[test]
        #[should_panic(expected = "non-string")]
        fn as_str_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_str();
        }

        #[test]
        #[should_panic(expected = "non-string")]
        fn as_str_panics_on_int_lit() {
            let arg: Arg = syn::parse_str("42").unwrap();
            arg.as_str();
        }

        #[test]
        fn as_int_from_lit() {
            let arg: Arg = syn::parse_str("42").unwrap();
            assert_eq!(arg.as_int::<i64>(), 42i64);
        }

        #[test]
        fn as_int_from_expr() {
            let arg: Arg = syn::parse_str("count = 7").unwrap();
            assert_eq!(arg.as_int::<i64>(), 7i64);
        }

        #[test]
        fn as_int_ignores_suffix() {
            let arg: Arg = syn::parse_str("n = 8u8").unwrap();
            assert_eq!(arg.as_int::<u8>(), 8u8);
        }

        #[test]
        #[should_panic(expected = "invalid integer literal")]
        fn as_int_panics_on_overflow() {
            let arg: Arg = syn::parse_str("n = 300").unwrap();
            arg.as_int::<u8>();
        }

        #[test]
        #[should_panic(expected = "non-integer")]
        fn as_int_panics_on_string_lit() {
            let arg: Arg = syn::parse_str("\"hello\"").unwrap();
            arg.as_int::<i64>();
        }

        #[test]
        fn as_float_from_lit() {
            let arg: Arg = syn::parse_str("1.5").unwrap();
            assert_eq!(arg.as_float::<f64>(), 1.5);
        }

        #[test]
        fn as_float_from_expr() {
            let arg: Arg = syn::parse_str("ratio = 0.25").unwrap();
            assert_eq!(arg.as_float::<f64>(), 0.25);
        }

        #[test]
        #[should_panic(expected = "non-float")]
        fn as_float_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_float::<f64>();
        }

        #[test]
        fn as_char_from_lit() {
            let arg: Arg = syn::parse_str("'x'").unwrap();
            assert_eq!(arg.as_char(), 'x');
        }

        #[test]
        fn as_char_from_expr() {
            let arg: Arg = syn::parse_str("sep = ','").unwrap();
            assert_eq!(arg.as_char(), ',');
        }

        #[test]
        #[should_panic(expected = "non-char")]
        fn as_char_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_char();
        }

        #[test]
        fn as_bool_from_lit() {
            let arg: Arg = syn::parse_str("true").unwrap();
            assert!(arg.as_bool());
        }

        #[test]
        fn as_bool_from_expr() {
            let arg: Arg = syn::parse_str("enabled = false").unwrap();
            assert!(!arg.as_bool());
        }

        #[test]
        #[should_panic(expected = "non-bool")]
        fn as_bool_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_bool();
        }
    }
}