Skip to main content

zyn_core/meta/
args.rs

1use std::ops::Index;
2
3use quote::ToTokens;
4use quote::TokenStreamExt;
5
6use syn::Token;
7use syn::parse::Parse;
8use syn::parse::ParseStream;
9
10use super::Arg;
11
12/// A parsed list of attribute arguments.
13///
14/// Represents the comma-separated contents inside an attribute
15/// (e.g. the `skip, rename = "foo"` in `#[my_attr(skip, rename = "foo")]`).
16#[derive(Clone, Default)]
17pub struct Args(Vec<Arg>);
18
19impl Args {
20    /// Creates an empty argument list.
21    pub fn new() -> Self {
22        Self(Vec::new())
23    }
24
25    /// Returns `true` if an argument with the given name exists.
26    pub fn has(&self, name: &str) -> bool {
27        self.0
28            .iter()
29            .any(|arg| arg.name().is_some_and(|n| n == name))
30    }
31
32    /// Returns the first argument with the given name.
33    pub fn get(&self, name: &str) -> Option<&Arg> {
34        self.0
35            .iter()
36            .find(|arg| arg.name().is_some_and(|n| n == name))
37    }
38
39    /// Returns the argument at the given position.
40    pub fn get_index(&self, index: usize) -> Option<&Arg> {
41        self.0.get(index)
42    }
43
44    /// Returns an iterator over the arguments.
45    pub fn iter(&self) -> impl Iterator<Item = &Arg> {
46        self.0.iter()
47    }
48
49    /// Returns the number of arguments.
50    pub fn len(&self) -> usize {
51        self.0.len()
52    }
53
54    /// Returns `true` if there are no arguments.
55    pub fn is_empty(&self) -> bool {
56        self.0.is_empty()
57    }
58
59    /// Appends all arguments from `other`, keeping duplicates.
60    pub fn extend(&mut self, other: Args) {
61        self.0.extend(other.0);
62    }
63
64    /// Merges two argument lists, with `other` taking precedence on name conflicts.
65    pub fn merge(&self, other: &Args) -> Args {
66        let mut result: Vec<Arg> = Vec::new();
67
68        for arg in &self.0 {
69            if let Some(name) = arg.name()
70                && other.has(&name.to_string())
71            {
72                continue;
73            }
74
75            result.push(arg.clone());
76        }
77
78        for arg in &other.0 {
79            result.push(arg.clone());
80        }
81
82        Args(result)
83    }
84}
85
86impl Index<usize> for Args {
87    type Output = Arg;
88
89    fn index(&self, index: usize) -> &Self::Output {
90        &self.0[index]
91    }
92}
93
94impl IntoIterator for Args {
95    type Item = Arg;
96    type IntoIter = std::vec::IntoIter<Arg>;
97
98    fn into_iter(self) -> Self::IntoIter {
99        self.0.into_iter()
100    }
101}
102
103impl<'a> IntoIterator for &'a Args {
104    type Item = &'a Arg;
105    type IntoIter = std::slice::Iter<'a, Arg>;
106
107    fn into_iter(self) -> Self::IntoIter {
108        self.0.iter()
109    }
110}
111
112impl Parse for Args {
113    fn parse(input: ParseStream) -> syn::Result<Self> {
114        let mut args = Vec::new();
115
116        while !input.is_empty() {
117            args.push(input.parse::<Arg>()?);
118
119            if input.peek(Token![,]) {
120                input.parse::<Token![,]>()?;
121            }
122        }
123
124        Ok(Self(args))
125    }
126}
127
128impl ToTokens for Args {
129    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
130        for (i, arg) in self.0.iter().enumerate() {
131            if i > 0 {
132                tokens.append(proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone));
133            }
134
135            arg.to_tokens(tokens);
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` into an `Args`, panicking if the input is rejected.
    fn parse_args(src: &str) -> Args {
        syn::parse_str(src).unwrap()
    }

    /// Parsing of the different argument shapes.
    mod parse {
        use super::*;

        #[test]
        fn empty() {
            let parsed = parse_args("");
            assert!(parsed.is_empty());
            assert_eq!(parsed.len(), 0);
        }

        #[test]
        fn single_flag() {
            let parsed = parse_args("skip");
            assert_eq!(parsed.len(), 1);
            assert!(parsed.has("skip"));
        }

        #[test]
        fn multiple() {
            let parsed = parse_args("skip, rename = \"foo\", serde(flatten)");
            assert_eq!(parsed.len(), 3);

            for expected in ["skip", "rename", "serde"] {
                assert!(parsed.has(expected));
            }
        }

        #[test]
        fn trailing_comma() {
            let parsed = parse_args("skip, rename = \"foo\",");
            assert_eq!(parsed.len(), 2);
        }

        #[test]
        fn lit() {
            let parsed = parse_args("\"hello\"");
            assert_eq!(parsed.len(), 1);
            assert!(parsed[0].is_lit());
        }

        #[test]
        fn nested() {
            let parsed = parse_args("outer(inner(deep = 1))");
            assert_eq!(parsed.len(), 1);

            // Descend one level at a time through the nested lists.
            let first_level = parsed.get("outer").unwrap().as_args();
            assert!(first_level.has("inner"));

            let second_level = first_level.get("inner").unwrap().as_args();
            assert!(second_level.has("deep"));
        }
    }

    /// Lookup, indexing and iteration over a parsed list.
    mod query {
        use super::*;

        #[test]
        fn has_existing() {
            let parsed = parse_args("skip, rename = \"foo\"");
            assert!(parsed.has("skip"));
            assert!(parsed.has("rename"));
        }

        #[test]
        fn has_missing() {
            let parsed = parse_args("skip");
            assert!(!parsed.has("rename"));
        }

        #[test]
        fn get_existing() {
            let parsed = parse_args("skip, rename = \"foo\"");

            let flag = parsed.get("skip").unwrap();
            assert!(flag.is_flag());

            let assignment = parsed.get("rename").unwrap();
            assert!(assignment.is_expr());
        }

        #[test]
        fn get_missing() {
            let parsed = parse_args("skip");
            assert!(parsed.get("rename").is_none());
        }

        #[test]
        fn index() {
            let parsed = parse_args("skip, rename = \"foo\"");
            assert!(parsed[0].is_flag());
            assert!(parsed[1].is_expr());
        }

        #[test]
        fn iter() {
            let parsed = parse_args("a, b, c");

            // Collect the name of every named argument, in order.
            let mut names: Vec<String> = Vec::new();
            for arg in parsed.iter() {
                if let Some(name) = arg.name() {
                    names.push(name.to_string());
                }
            }

            assert_eq!(names, vec!["a", "b", "c"]);
        }
    }

    /// Combining two lists with `merge` and `extend`.
    mod merge {
        use super::*;

        #[test]
        fn no_overlap() {
            let base = parse_args("skip");
            let overrides = parse_args("rename = \"foo\"");

            let combined = base.merge(&overrides);

            assert_eq!(combined.len(), 2);
            assert!(combined.has("skip"));
            assert!(combined.has("rename"));
        }

        #[test]
        fn override_key() {
            let base = parse_args("rename = \"foo\"");
            let overrides = parse_args("rename = \"bar\"");

            // The shared name must appear only once in the result.
            let combined = base.merge(&overrides);
            assert_eq!(combined.len(), 1);
        }

        #[test]
        fn extend_keeps_duplicates() {
            let mut target = parse_args("skip");
            let extra = parse_args("skip, rename = \"foo\"");

            target.extend(extra);

            assert_eq!(target.len(), 3);
        }

        #[test]
        fn empty_merge() {
            let base = parse_args("skip");
            let nothing = parse_args("");

            let combined = base.merge(&nothing);

            assert_eq!(combined.len(), 1);
            assert!(combined.has("skip"));
        }
    }

    /// Token output.
    mod to_tokens {
        use super::*;

        #[test]
        fn round_trip() {
            let source = "skip , rename = \"foo\" , serde (flatten)";
            let original = parse_args(source);

            // Emit tokens, render them to text, and parse that text again.
            let rendered = original.to_token_stream().to_string();
            let reparsed = parse_args(&rendered);

            assert_eq!(reparsed.len(), 3);
            for expected in ["skip", "rename", "serde"] {
                assert!(reparsed.has(expected));
            }
        }
    }
}