Skip to main content

zyn_core/meta/
args.rs

1use std::ops::Index;
2
3use quote::ToTokens;
4use quote::TokenStreamExt;
5
6use syn::Token;
7use syn::parse::Parse;
8use syn::parse::ParseStream;
9
10use super::Arg;
11
12#[derive(Default)]
13pub struct Args(Vec<Arg>);
14
15impl Args {
16    pub fn new() -> Self {
17        Self(Vec::new())
18    }
19
20    pub fn has(&self, name: &str) -> bool {
21        self.0
22            .iter()
23            .any(|arg| arg.name().is_some_and(|n| n == name))
24    }
25
26    pub fn get(&self, name: &str) -> Option<&Arg> {
27        self.0
28            .iter()
29            .find(|arg| arg.name().is_some_and(|n| n == name))
30    }
31
32    pub fn get_index(&self, index: usize) -> Option<&Arg> {
33        self.0.get(index)
34    }
35
36    pub fn iter(&self) -> impl Iterator<Item = &Arg> {
37        self.0.iter()
38    }
39
40    pub fn len(&self) -> usize {
41        self.0.len()
42    }
43
44    pub fn is_empty(&self) -> bool {
45        self.0.is_empty()
46    }
47
48    pub fn extend(&mut self, other: Args) {
49        self.0.extend(other.0);
50    }
51
52    pub fn merge(&self, other: &Args) -> Args {
53        let mut result: Vec<Arg> = Vec::new();
54
55        for arg in &self.0 {
56            if let Some(name) = arg.name()
57                && other.has(&name.to_string())
58            {
59                continue;
60            }
61
62            result.push(clone_arg(arg));
63        }
64
65        for arg in &other.0 {
66            result.push(clone_arg(arg));
67        }
68
69        Args(result)
70    }
71}
72
73fn clone_arg(arg: &Arg) -> Arg {
74    syn::parse2(arg.to_token_stream()).unwrap()
75}
76
77impl Index<usize> for Args {
78    type Output = Arg;
79
80    fn index(&self, index: usize) -> &Self::Output {
81        &self.0[index]
82    }
83}
84
85impl IntoIterator for Args {
86    type Item = Arg;
87    type IntoIter = std::vec::IntoIter<Arg>;
88
89    fn into_iter(self) -> Self::IntoIter {
90        self.0.into_iter()
91    }
92}
93
94impl<'a> IntoIterator for &'a Args {
95    type Item = &'a Arg;
96    type IntoIter = std::slice::Iter<'a, Arg>;
97
98    fn into_iter(self) -> Self::IntoIter {
99        self.0.iter()
100    }
101}
102
103impl Parse for Args {
104    fn parse(input: ParseStream) -> syn::Result<Self> {
105        let mut args = Vec::new();
106
107        while !input.is_empty() {
108            args.push(input.parse::<Arg>()?);
109
110            if input.peek(Token![,]) {
111                input.parse::<Token![,]>()?;
112            }
113        }
114
115        Ok(Self(args))
116    }
117}
118
119impl ToTokens for Args {
120    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
121        for (i, arg) in self.0.iter().enumerate() {
122            if i > 0 {
123                tokens.append(proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone));
124            }
125
126            arg.to_tokens(tokens);
127        }
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` into [`Args`], panicking if the input is rejected.
    fn parse_args(src: &str) -> Args {
        syn::parse_str(src).unwrap()
    }

    /// Parsing of the supported argument shapes.
    mod parse {
        use super::*;

        #[test]
        fn empty() {
            let parsed = parse_args("");
            assert!(parsed.is_empty());
            assert_eq!(parsed.len(), 0);
        }

        #[test]
        fn single_flag() {
            let parsed = parse_args("skip");
            assert_eq!(parsed.len(), 1);
            assert!(parsed.has("skip"));
        }

        #[test]
        fn multiple() {
            let parsed = parse_args("skip, rename = \"foo\", serde(flatten)");
            assert_eq!(parsed.len(), 3);

            for expected in ["skip", "rename", "serde"] {
                assert!(parsed.has(expected));
            }
        }

        #[test]
        fn trailing_comma() {
            let parsed = parse_args("skip, rename = \"foo\",");
            assert_eq!(parsed.len(), 2);
        }

        #[test]
        fn lit() {
            let parsed = parse_args("\"hello\"");
            assert_eq!(parsed.len(), 1);
            assert!(parsed[0].is_lit());
        }

        #[test]
        fn nested() {
            let parsed = parse_args("outer(inner(deep = 1))");
            assert_eq!(parsed.len(), 1);

            // Walk down one nesting level at a time.
            let level_one = parsed.get("outer").unwrap().as_args();
            assert!(level_one.has("inner"));

            let level_two = level_one.get("inner").unwrap().as_args();
            assert!(level_two.has("deep"));
        }
    }

    /// Lookup, indexing and iteration over an already parsed list.
    mod query {
        use super::*;

        #[test]
        fn has_existing() {
            let parsed = parse_args("skip, rename = \"foo\"");
            assert!(parsed.has("skip"));
            assert!(parsed.has("rename"));
        }

        #[test]
        fn has_missing() {
            let parsed = parse_args("skip");
            assert!(!parsed.has("rename"));
        }

        #[test]
        fn get_existing() {
            let parsed = parse_args("skip, rename = \"foo\"");

            let flag = parsed.get("skip").unwrap();
            assert!(flag.is_flag());

            let expr = parsed.get("rename").unwrap();
            assert!(expr.is_expr());
        }

        #[test]
        fn get_missing() {
            let parsed = parse_args("skip");
            assert!(parsed.get("rename").is_none());
        }

        #[test]
        fn index() {
            let parsed = parse_args("skip, rename = \"foo\"");
            assert!(parsed[0].is_flag());
            assert!(parsed[1].is_expr());
        }

        #[test]
        fn iter() {
            let parsed = parse_args("a, b, c");
            let names: Vec<String> = parsed
                .iter()
                .filter_map(|arg| arg.name().map(|name| name.to_string()))
                .collect();
            assert_eq!(names, ["a", "b", "c"]);
        }
    }

    /// Combining two lists with `merge` and `extend`.
    mod merge {
        use super::*;

        #[test]
        fn no_overlap() {
            let base = parse_args("skip");
            let overrides = parse_args("rename = \"foo\"");

            let combined = base.merge(&overrides);
            assert_eq!(combined.len(), 2);
            assert!(combined.has("skip"));
            assert!(combined.has("rename"));
        }

        #[test]
        fn override_key() {
            let base = parse_args("rename = \"foo\"");
            let overrides = parse_args("rename = \"bar\"");

            let combined = base.merge(&overrides);
            assert_eq!(combined.len(), 1);
        }

        #[test]
        fn extend_keeps_duplicates() {
            let mut base = parse_args("skip");
            let extra = parse_args("skip, rename = \"foo\"");

            base.extend(extra);
            assert_eq!(base.len(), 3);
        }

        #[test]
        fn empty_merge() {
            let base = parse_args("skip");
            let overrides = parse_args("");

            let combined = base.merge(&overrides);
            assert_eq!(combined.len(), 1);
            assert!(combined.has("skip"));
        }
    }

    /// Token emission.
    mod to_tokens {
        use super::*;

        #[test]
        fn round_trip() {
            let original = parse_args("skip , rename = \"foo\" , serde (flatten)");

            // Emit tokens, render them to text, and parse that text again.
            let rendered = original.to_token_stream().to_string();
            let reparsed = parse_args(&rendered);

            assert_eq!(reparsed.len(), 3);
            for expected in ["skip", "rename", "serde"] {
                assert!(reparsed.has(expected));
            }
        }
    }
}