Skip to main content

zyn_core/meta/
arg.rs

1use proc_macro2::Ident;
2
3use quote::ToTokens;
4use quote::TokenStreamExt;
5
6use syn::Expr;
7use syn::Lit;
8use syn::Token;
9use syn::parse::Parse;
10use syn::parse::ParseStream;
11
12use super::Args;
13
14/// A single parsed attribute argument.
15///
16/// Represents one of four forms: a bare flag (`skip`), a key-value expression
17/// (`rename = "foo"`), a nested list (`serde(flatten)`), or a standalone literal (`"hello"`).
18#[derive(Clone)]
19pub enum Arg {
20    Flag(Ident),
21    Expr(Ident, Expr),
22    List(Ident, Args),
23    Lit(Lit),
24}
25
26impl Arg {
27    /// Returns the name of this argument, or `None` for standalone literals.
28    pub fn name(&self) -> Option<&Ident> {
29        match self {
30            Self::Flag(name) => Some(name),
31            Self::Expr(name, _) => Some(name),
32            Self::List(name, _) => Some(name),
33            Self::Lit(_) => None,
34        }
35    }
36
37    /// Returns `true` if this is a bare flag (e.g. `skip`).
38    pub fn is_flag(&self) -> bool {
39        matches!(self, Self::Flag(_))
40    }
41
42    /// Returns `true` if this is a key-value expression (e.g. `rename = "foo"`).
43    pub fn is_expr(&self) -> bool {
44        matches!(self, Self::Expr(_, _))
45    }
46
47    /// Returns `true` if this is a nested list (e.g. `serde(flatten)`).
48    pub fn is_list(&self) -> bool {
49        matches!(self, Self::List(_, _))
50    }
51
52    /// Returns `true` if this is a standalone literal.
53    pub fn is_lit(&self) -> bool {
54        matches!(self, Self::Lit(_))
55    }
56
57    /// Returns the expression value. Panics if not an `Expr` variant.
58    pub fn as_expr(&self) -> &Expr {
59        match self {
60            Self::Expr(_, expr) => expr,
61            _ => panic!("called `Arg::as_expr()` on a non-Expr variant"),
62        }
63    }
64
65    /// Returns the nested args. Panics if not a `List` variant.
66    pub fn as_args(&self) -> &Args {
67        match self {
68            Self::List(_, args) => args,
69            _ => panic!("called `Arg::as_args()` on a non-List variant"),
70        }
71    }
72
73    /// Returns the literal value. Panics if not a `Lit` variant.
74    pub fn as_lit(&self) -> &Lit {
75        match self {
76            Self::Lit(lit) => lit,
77            _ => panic!("called `Arg::as_lit()` on a non-Lit variant"),
78        }
79    }
80
81    /// Returns the flag identifier. Panics if not a `Flag` variant.
82    pub fn as_flag(&self) -> &Ident {
83        match self {
84            Self::Flag(i) => i,
85            _ => panic!("called `Arg::as_flag()` on a non-Flag variant"),
86        }
87    }
88
89    /// Extracts the string value. Panics if the value is not a string literal.
90    pub fn as_str(&self) -> String {
91        match self {
92            Self::Lit(Lit::Str(s)) => s.value(),
93            Self::Expr(
94                _,
95                syn::Expr::Lit(syn::ExprLit {
96                    lit: Lit::Str(s), ..
97                }),
98            ) => s.value(),
99            _ => panic!("called `Arg::as_str()` on a non-string variant"),
100        }
101    }
102
103    /// Parses the integer value. Panics if the value is not an integer literal.
104    pub fn as_int<T: std::str::FromStr>(&self) -> T
105    where
106        T::Err: std::fmt::Display,
107    {
108        match self {
109            Self::Lit(Lit::Int(i)) => i.base10_parse().expect("invalid integer literal"),
110            Self::Expr(
111                _,
112                syn::Expr::Lit(syn::ExprLit {
113                    lit: Lit::Int(i), ..
114                }),
115            ) => i.base10_parse().expect("invalid integer literal"),
116            _ => panic!("called `Arg::as_int()` on a non-integer variant"),
117        }
118    }
119
120    /// Parses the float value. Panics if the value is not a float literal.
121    pub fn as_float<T: std::str::FromStr>(&self) -> T
122    where
123        T::Err: std::fmt::Display,
124    {
125        match self {
126            Self::Lit(Lit::Float(f)) => f.base10_parse().expect("invalid float literal"),
127            Self::Expr(
128                _,
129                syn::Expr::Lit(syn::ExprLit {
130                    lit: Lit::Float(f), ..
131                }),
132            ) => f.base10_parse().expect("invalid float literal"),
133            _ => panic!("called `Arg::as_float()` on a non-float variant"),
134        }
135    }
136
137    /// Extracts the char value. Panics if the value is not a char literal.
138    pub fn as_char(&self) -> char {
139        match self {
140            Self::Lit(Lit::Char(c)) => c.value(),
141            Self::Expr(
142                _,
143                syn::Expr::Lit(syn::ExprLit {
144                    lit: Lit::Char(c), ..
145                }),
146            ) => c.value(),
147            _ => panic!("called `Arg::as_char()` on a non-char variant"),
148        }
149    }
150
151    /// Returns the literal if this is a `Lit` or an `Expr` containing a literal.
152    pub fn as_expr_lit(&self) -> Option<&Lit> {
153        match self {
154            Self::Lit(lit) => Some(lit),
155            Self::Expr(_, Expr::Lit(syn::ExprLit { lit, .. })) => Some(lit),
156            _ => None,
157        }
158    }
159
160    /// Extracts the bool value. Panics if the value is not a bool literal.
161    pub fn as_bool(&self) -> bool {
162        match self {
163            Self::Lit(Lit::Bool(b)) => b.value,
164            Self::Expr(
165                _,
166                syn::Expr::Lit(syn::ExprLit {
167                    lit: Lit::Bool(b), ..
168                }),
169            ) => b.value,
170            _ => panic!("called `Arg::as_bool()` on a non-bool variant"),
171        }
172    }
173}
174
175impl Parse for Arg {
176    fn parse(input: ParseStream) -> syn::Result<Self> {
177        if input.peek(Lit) || input.peek(syn::LitStr) {
178            return Ok(Self::Lit(input.parse()?));
179        }
180
181        let name: Ident = input.parse()?;
182
183        if input.peek(Token![=]) {
184            input.parse::<Token![=]>()?;
185            let expr: Expr = input.parse()?;
186            Ok(Self::Expr(name, expr))
187        } else if input.peek(syn::token::Paren) {
188            let content;
189            syn::parenthesized!(content in input);
190            let args: Args = content.parse()?;
191            Ok(Self::List(name, args))
192        } else {
193            Ok(Self::Flag(name))
194        }
195    }
196}
197
198impl ToTokens for Arg {
199    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
200        match self {
201            Self::Flag(name) => name.to_tokens(tokens),
202            Self::Expr(name, expr) => {
203                name.to_tokens(tokens);
204                tokens.append(proc_macro2::Punct::new('=', proc_macro2::Spacing::Alone));
205                expr.to_tokens(tokens);
206            }
207            Self::List(name, args) => {
208                name.to_tokens(tokens);
209                let inner = args.to_token_stream();
210                tokens.append(proc_macro2::Group::new(
211                    proc_macro2::Delimiter::Parenthesis,
212                    inner,
213                ));
214            }
215            Self::Lit(lit) => lit.to_tokens(tokens),
216        }
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    // Each of the four argument shapes parses into the matching variant.
    mod parse {
        use super::*;

        #[test]
        fn flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            assert!(arg.is_flag());
            assert_eq!(arg.name().unwrap(), "skip");
        }

        #[test]
        fn expr() {
            let arg: Arg = syn::parse_str("rename = \"foo\"").unwrap();
            assert!(arg.is_expr());
            assert_eq!(arg.name().unwrap(), "rename");
        }

        #[test]
        fn list() {
            let arg: Arg = syn::parse_str("serde(rename_all = \"camelCase\")").unwrap();
            assert!(arg.is_list());
            assert_eq!(arg.name().unwrap(), "serde");
            assert!(arg.as_args().has("rename_all"));
        }

        #[test]
        fn lit_str() {
            let arg: Arg = syn::parse_str("\"hello\"").unwrap();
            assert!(arg.is_lit());
            assert!(arg.name().is_none());
        }

        #[test]
        fn lit_int() {
            let arg: Arg = syn::parse_str("42").unwrap();
            assert!(arg.is_lit());
        }
    }

    // Parsed arguments print back as the tokens they were parsed from.
    mod to_tokens {
        use super::*;

        #[test]
        fn flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "skip");
        }

        #[test]
        fn expr() {
            let arg: Arg = syn::parse_str("rename = \"foo\"").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "rename = \"foo\"");
        }

        #[test]
        fn list() {
            let arg: Arg = syn::parse_str("serde(skip)").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "serde (skip)");
        }

        #[test]
        fn lit() {
            let arg: Arg = syn::parse_str("\"hello\"").unwrap();
            let output = arg.to_token_stream().to_string();
            assert_eq!(output, "\"hello\"");
        }
    }

    // Accessors return the expected value on the matching variant and panic
    // with an identifying message on any other.
    mod accessors {
        use super::*;

        #[test]
        #[should_panic(expected = "non-Expr")]
        fn as_expr_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_expr();
        }

        #[test]
        #[should_panic(expected = "non-List")]
        fn as_args_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_args();
        }

        #[test]
        #[should_panic(expected = "non-Lit")]
        fn as_lit_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_lit();
        }

        #[test]
        fn as_flag_returns_ident() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            assert_eq!(arg.as_flag().to_string(), "skip");
        }

        #[test]
        #[should_panic(expected = "non-Flag")]
        fn as_flag_panics_on_expr() {
            let arg: Arg = syn::parse_str("x = 1").unwrap();
            arg.as_flag();
        }

        #[test]
        fn as_str_from_lit() {
            let arg: Arg = syn::parse_str("\"hello\"").unwrap();
            assert_eq!(arg.as_str(), "hello");
        }

        #[test]
        fn as_str_from_expr() {
            let arg: Arg = syn::parse_str("rename = \"foo\"").unwrap();
            assert_eq!(arg.as_str(), "foo");
        }

        #[test]
        #[should_panic(expected = "non-string")]
        fn as_str_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_str();
        }

        #[test]
        fn as_int_from_lit() {
            let arg: Arg = syn::parse_str("42").unwrap();
            assert_eq!(arg.as_int::<i64>(), 42i64);
        }

        #[test]
        fn as_int_from_expr() {
            let arg: Arg = syn::parse_str("count = 7").unwrap();
            assert_eq!(arg.as_int::<i64>(), 7i64);
        }

        #[test]
        #[should_panic(expected = "non-integer")]
        fn as_int_panics_on_string_lit() {
            let arg: Arg = syn::parse_str("\"hello\"").unwrap();
            arg.as_int::<i64>();
        }

        #[test]
        fn as_float_from_lit() {
            let arg: Arg = syn::parse_str("1.5").unwrap();
            assert_eq!(arg.as_float::<f64>(), 1.5f64);
        }

        #[test]
        fn as_float_from_expr() {
            let arg: Arg = syn::parse_str("ratio = 2.5").unwrap();
            assert_eq!(arg.as_float::<f64>(), 2.5f64);
        }

        #[test]
        #[should_panic(expected = "non-float")]
        fn as_float_panics_on_int_lit() {
            let arg: Arg = syn::parse_str("42").unwrap();
            arg.as_float::<f64>();
        }

        #[test]
        fn as_char_from_lit() {
            let arg: Arg = syn::parse_str("'x'").unwrap();
            assert_eq!(arg.as_char(), 'x');
        }

        #[test]
        fn as_char_from_expr() {
            let arg: Arg = syn::parse_str("sep = ','").unwrap();
            assert_eq!(arg.as_char(), ',');
        }

        #[test]
        #[should_panic(expected = "non-char")]
        fn as_char_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_char();
        }

        #[test]
        fn as_bool_from_lit() {
            let arg: Arg = syn::parse_str("true").unwrap();
            assert!(arg.as_bool());
        }

        #[test]
        fn as_bool_from_expr() {
            let arg: Arg = syn::parse_str("enabled = false").unwrap();
            assert!(!arg.as_bool());
        }

        #[test]
        #[should_panic(expected = "non-bool")]
        fn as_bool_panics_on_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            arg.as_bool();
        }

        #[test]
        fn as_expr_lit_from_lit() {
            let arg: Arg = syn::parse_str("42").unwrap();
            assert!(matches!(arg.as_expr_lit(), Some(Lit::Int(_))));
        }

        #[test]
        fn as_expr_lit_from_expr() {
            let arg: Arg = syn::parse_str("rename = \"foo\"").unwrap();
            assert!(matches!(arg.as_expr_lit(), Some(Lit::Str(_))));
        }

        #[test]
        fn as_expr_lit_none_for_flag() {
            let arg: Arg = syn::parse_str("skip").unwrap();
            assert!(arg.as_expr_lit().is_none());
        }
    }
}