Skip to main content

zyn_core/meta/
arg.rs

1use proc_macro2::Ident;
2
3use quote::ToTokens;
4use quote::TokenStreamExt;
5
6use syn::Expr;
7use syn::Lit;
8use syn::Token;
9use syn::parse::Parse;
10use syn::parse::ParseStream;
11
12use super::Args;
13
/// One argument inside an attribute-style argument list.
///
/// Each variant corresponds to one syntactic form accepted by this type's
/// `Parse` impl.
pub enum Arg {
    /// A bare identifier, e.g. `skip`.
    Flag(Ident),
    /// An identifier assigned an expression, e.g. `rename = "foo"`.
    Expr(Ident, Expr),
    /// An identifier followed by a parenthesized nested argument list,
    /// e.g. `serde(rename_all = "camelCase")`.
    List(Ident, Args),
    /// An unnamed literal, e.g. `"hello"` or `42`.
    Lit(Lit),
}
20
21impl Arg {
22    pub fn name(&self) -> Option<&Ident> {
23        match self {
24            Self::Flag(name) => Some(name),
25            Self::Expr(name, _) => Some(name),
26            Self::List(name, _) => Some(name),
27            Self::Lit(_) => None,
28        }
29    }
30
31    pub fn is_flag(&self) -> bool {
32        matches!(self, Self::Flag(_))
33    }
34
35    pub fn is_expr(&self) -> bool {
36        matches!(self, Self::Expr(_, _))
37    }
38
39    pub fn is_list(&self) -> bool {
40        matches!(self, Self::List(_, _))
41    }
42
43    pub fn is_lit(&self) -> bool {
44        matches!(self, Self::Lit(_))
45    }
46
47    pub fn as_expr(&self) -> &Expr {
48        match self {
49            Self::Expr(_, expr) => expr,
50            _ => panic!("called `Arg::as_expr()` on a non-Expr variant"),
51        }
52    }
53
54    pub fn as_args(&self) -> &Args {
55        match self {
56            Self::List(_, args) => args,
57            _ => panic!("called `Arg::as_args()` on a non-List variant"),
58        }
59    }
60
61    pub fn as_lit(&self) -> &Lit {
62        match self {
63            Self::Lit(lit) => lit,
64            _ => panic!("called `Arg::as_lit()` on a non-Lit variant"),
65        }
66    }
67
68    pub fn as_flag(&self) -> &Ident {
69        match self {
70            Self::Flag(i) => i,
71            _ => panic!("called `Arg::as_flag()` on a non-Flag variant"),
72        }
73    }
74
75    pub fn as_str(&self) -> String {
76        match self {
77            Self::Lit(Lit::Str(s)) => s.value(),
78            Self::Expr(
79                _,
80                syn::Expr::Lit(syn::ExprLit {
81                    lit: Lit::Str(s), ..
82                }),
83            ) => s.value(),
84            _ => panic!("called `Arg::as_str()` on a non-string variant"),
85        }
86    }
87
88    pub fn as_int<T: std::str::FromStr>(&self) -> T
89    where
90        T::Err: std::fmt::Display,
91    {
92        match self {
93            Self::Lit(Lit::Int(i)) => i.base10_parse().expect("invalid integer literal"),
94            Self::Expr(
95                _,
96                syn::Expr::Lit(syn::ExprLit {
97                    lit: Lit::Int(i), ..
98                }),
99            ) => i.base10_parse().expect("invalid integer literal"),
100            _ => panic!("called `Arg::as_int()` on a non-integer variant"),
101        }
102    }
103
104    pub fn as_float<T: std::str::FromStr>(&self) -> T
105    where
106        T::Err: std::fmt::Display,
107    {
108        match self {
109            Self::Lit(Lit::Float(f)) => f.base10_parse().expect("invalid float literal"),
110            Self::Expr(
111                _,
112                syn::Expr::Lit(syn::ExprLit {
113                    lit: Lit::Float(f), ..
114                }),
115            ) => f.base10_parse().expect("invalid float literal"),
116            _ => panic!("called `Arg::as_float()` on a non-float variant"),
117        }
118    }
119
120    pub fn as_char(&self) -> char {
121        match self {
122            Self::Lit(Lit::Char(c)) => c.value(),
123            Self::Expr(
124                _,
125                syn::Expr::Lit(syn::ExprLit {
126                    lit: Lit::Char(c), ..
127                }),
128            ) => c.value(),
129            _ => panic!("called `Arg::as_char()` on a non-char variant"),
130        }
131    }
132
133    pub fn as_expr_lit(&self) -> Option<&Lit> {
134        match self {
135            Self::Lit(lit) => Some(lit),
136            Self::Expr(_, Expr::Lit(syn::ExprLit { lit, .. })) => Some(lit),
137            _ => None,
138        }
139    }
140
141    pub fn as_bool(&self) -> bool {
142        match self {
143            Self::Lit(Lit::Bool(b)) => b.value,
144            Self::Expr(
145                _,
146                syn::Expr::Lit(syn::ExprLit {
147                    lit: Lit::Bool(b), ..
148                }),
149            ) => b.value,
150            _ => panic!("called `Arg::as_bool()` on a non-bool variant"),
151        }
152    }
153}
154
155impl Parse for Arg {
156    fn parse(input: ParseStream) -> syn::Result<Self> {
157        if input.peek(Lit) || input.peek(syn::LitStr) {
158            return Ok(Self::Lit(input.parse()?));
159        }
160
161        let name: Ident = input.parse()?;
162
163        if input.peek(Token![=]) {
164            input.parse::<Token![=]>()?;
165            let expr: Expr = input.parse()?;
166            Ok(Self::Expr(name, expr))
167        } else if input.peek(syn::token::Paren) {
168            let content;
169            syn::parenthesized!(content in input);
170            let args: Args = content.parse()?;
171            Ok(Self::List(name, args))
172        } else {
173            Ok(Self::Flag(name))
174        }
175    }
176}
177
178impl ToTokens for Arg {
179    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
180        match self {
181            Self::Flag(name) => name.to_tokens(tokens),
182            Self::Expr(name, expr) => {
183                name.to_tokens(tokens);
184                tokens.append(proc_macro2::Punct::new('=', proc_macro2::Spacing::Alone));
185                expr.to_tokens(tokens);
186            }
187            Self::List(name, args) => {
188                name.to_tokens(tokens);
189                let inner = args.to_token_stream();
190                tokens.append(proc_macro2::Group::new(
191                    proc_macro2::Delimiter::Parenthesis,
192                    inner,
193                ));
194            }
195            Self::Lit(lit) => lit.to_tokens(tokens),
196        }
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    mod parse {
        use super::*;

        #[test]
        fn flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            assert!(arg.is_flag());
            assert_eq!(arg.name().unwrap(), "skip");
        }

        #[test]
        fn expr() {
            let arg: Arg = syn::parse_str("rename = \"foo\"").unwrap();
            assert!(arg.is_expr());
            assert_eq!(arg.name().unwrap(), "rename");
        }

        #[test]
        fn expr_with_int() {
            let arg: Arg = syn::parse_str("count = 7").unwrap();
            assert!(arg.is_expr());
            assert_eq!(arg.name().unwrap(), "count");
        }

        #[test]
        fn list() {
            let arg: Arg = syn::parse_str("serde(rename_all = \"camelCase\")").unwrap();
            assert!(arg.is_list());
            assert_eq!(arg.name().unwrap(), "serde");
            assert!(arg.as_args().has("rename_all"));
        }

        #[test]
        fn lit_str() {
            let arg: Arg = syn::parse_str("\"hello\"").unwrap();
            assert!(arg.is_lit());
            assert!(arg.name().is_none());
        }

        #[test]
        fn lit_int() {
            let arg: Arg = syn::parse_str("42").unwrap();
            assert!(arg.is_lit());
        }

        // `true`/`false` must be parsed as literals, not as flags.
        #[test]
        fn lit_bool() {
            let arg: Arg = syn::parse_str("true").unwrap();
            assert!(arg.is_lit());
            assert!(arg.name().is_none());
        }
    }

    mod to_tokens {
        use super::*;

        #[test]
        fn flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "skip");
        }

        #[test]
        fn expr() {
            let arg: Arg = syn::parse_str("rename = \"foo\"").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "rename = \"foo\"");
        }

        #[test]
        fn expr_int() {
            let arg: Arg = syn::parse_str("count = 7").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "count = 7");
        }

        #[test]
        fn list() {
            let arg: Arg = syn::parse_str("serde(skip)").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "serde (skip)");
        }

        #[test]
        fn lit() {
            let arg: Arg = syn::parse_str("\"hello\"").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "\"hello\"");
        }
    }

    mod accessors {
        use super::*;

        #[test]
        #[should_panic(expected = "non-Expr")]
        fn as_expr_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_expr();
        }

        #[test]
        fn as_expr_returns_rhs() {
            let arg: Arg = syn::parse_str("x = a + b").unwrap();
            assert_eq!(arg.as_expr().to_token_stream().to_string(), "a + b");
        }

        #[test]
        #[should_panic(expected = "non-List")]
        fn as_args_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_args();
        }

        #[test]
        #[should_panic(expected = "non-Lit")]
        fn as_lit_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_lit();
        }

        #[test]
        fn as_flag_returns_ident() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            assert_eq!(arg.as_flag().to_string(), "skip");
        }

        #[test]
        #[should_panic(expected = "non-Flag")]
        fn as_flag_panics_on_expr() {
            let arg: Arg = syn::parse_str("x = 1").unwrap();
            arg.as_flag();
        }

        #[test]
        fn as_str_from_lit() {
            let arg: Arg = syn::parse_str("\"hello\"").unwrap();
            assert_eq!(arg.as_str(), "hello");
        }

        #[test]
        fn as_str_from_expr() {
            let arg: Arg = syn::parse_str("rename = \"foo\"").unwrap();
            assert_eq!(arg.as_str(), "foo");
        }

        #[test]
        #[should_panic(expected = "non-string")]
        fn as_str_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_str();
        }

        #[test]
        fn as_int_from_lit() {
            let arg: Arg = syn::parse_str("42").unwrap();
            assert_eq!(arg.as_int::<i64>(), 42i64);
        }

        #[test]
        fn as_int_from_expr() {
            let arg: Arg = syn::parse_str("count = 7").unwrap();
            assert_eq!(arg.as_int::<i64>(), 7i64);
        }

        #[test]
        #[should_panic(expected = "non-integer")]
        fn as_int_panics_on_string_lit() {
            let arg: Arg = syn::parse_str("\"hello\"").unwrap();
            arg.as_int::<i64>();
        }

        #[test]
        fn as_float_from_lit() {
            let arg: Arg = syn::parse_str("3.5").unwrap();
            assert_eq!(arg.as_float::<f64>(), 3.5f64);
        }

        #[test]
        fn as_float_from_expr() {
            let arg: Arg = syn::parse_str("ratio = 0.25").unwrap();
            assert_eq!(arg.as_float::<f64>(), 0.25f64);
        }

        // An integer literal is a distinct `Lit` kind and is not coerced.
        #[test]
        #[should_panic(expected = "non-float")]
        fn as_float_panics_on_int_lit() {
            let arg: Arg = syn::parse_str("42").unwrap();
            arg.as_float::<f64>();
        }

        #[test]
        fn as_char_from_lit() {
            let arg: Arg = syn::parse_str("'x'").unwrap();
            assert_eq!(arg.as_char(), 'x');
        }

        #[test]
        fn as_char_from_expr() {
            let arg: Arg = syn::parse_str("sep = ','").unwrap();
            assert_eq!(arg.as_char(), ',');
        }

        #[test]
        #[should_panic(expected = "non-char")]
        fn as_char_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_char();
        }

        #[test]
        fn as_bool_from_lit() {
            let arg: Arg = syn::parse_str("true").unwrap();
            assert!(arg.as_bool());
        }

        #[test]
        fn as_bool_from_expr() {
            let arg: Arg = syn::parse_str("enabled = false").unwrap();
            assert!(!arg.as_bool());
        }

        #[test]
        #[should_panic(expected = "non-bool")]
        fn as_bool_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_bool();
        }

        #[test]
        fn as_expr_lit_accepts_bare_and_assigned_literals_only() {
            let bare: Arg = syn::parse_str("42").unwrap();
            assert!(matches!(bare.as_expr_lit(), Some(syn::Lit::Int(_))));

            let assigned: Arg = syn::parse_str("x = \"s\"").unwrap();
            assert!(matches!(assigned.as_expr_lit(), Some(syn::Lit::Str(_))));

            let computed: Arg = syn::parse_str("x = a + b").unwrap();
            assert!(computed.as_expr_lit().is_none());

            let flag: Arg = syn::parse_str("skip").unwrap();
            assert!(flag.as_expr_lit().is_none());
        }
    }
}