// zyn_core/meta/args.rs

1use std::ops::Index;
2
3use quote::ToTokens;
4use quote::TokenStreamExt;
5
6use syn::Token;
7use syn::parse::Parse;
8use syn::parse::ParseStream;
9
10use super::Arg;
11
12#[derive(Clone, Default)]
13pub struct Args(Vec<Arg>);
14
15impl Args {
16    pub fn new() -> Self {
17        Self(Vec::new())
18    }
19
20    pub fn has(&self, name: &str) -> bool {
21        self.0
22            .iter()
23            .any(|arg| arg.name().is_some_and(|n| n == name))
24    }
25
26    pub fn get(&self, name: &str) -> Option<&Arg> {
27        self.0
28            .iter()
29            .find(|arg| arg.name().is_some_and(|n| n == name))
30    }
31
32    pub fn get_index(&self, index: usize) -> Option<&Arg> {
33        self.0.get(index)
34    }
35
36    pub fn iter(&self) -> impl Iterator<Item = &Arg> {
37        self.0.iter()
38    }
39
40    pub fn len(&self) -> usize {
41        self.0.len()
42    }
43
44    pub fn is_empty(&self) -> bool {
45        self.0.is_empty()
46    }
47
48    pub fn extend(&mut self, other: Args) {
49        self.0.extend(other.0);
50    }
51
52    pub fn merge(&self, other: &Args) -> Args {
53        let mut result: Vec<Arg> = Vec::new();
54
55        for arg in &self.0 {
56            if let Some(name) = arg.name()
57                && other.has(&name.to_string())
58            {
59                continue;
60            }
61
62            result.push(arg.clone());
63        }
64
65        for arg in &other.0 {
66            result.push(arg.clone());
67        }
68
69        Args(result)
70    }
71}
72
73impl Index<usize> for Args {
74    type Output = Arg;
75
76    fn index(&self, index: usize) -> &Self::Output {
77        &self.0[index]
78    }
79}
80
81impl IntoIterator for Args {
82    type Item = Arg;
83    type IntoIter = std::vec::IntoIter<Arg>;
84
85    fn into_iter(self) -> Self::IntoIter {
86        self.0.into_iter()
87    }
88}
89
90impl<'a> IntoIterator for &'a Args {
91    type Item = &'a Arg;
92    type IntoIter = std::slice::Iter<'a, Arg>;
93
94    fn into_iter(self) -> Self::IntoIter {
95        self.0.iter()
96    }
97}
98
99impl Parse for Args {
100    fn parse(input: ParseStream) -> syn::Result<Self> {
101        let mut args = Vec::new();
102
103        while !input.is_empty() {
104            args.push(input.parse::<Arg>()?);
105
106            if input.peek(Token![,]) {
107                input.parse::<Token![,]>()?;
108            }
109        }
110
111        Ok(Self(args))
112    }
113}
114
115impl ToTokens for Args {
116    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
117        for (i, arg) in self.0.iter().enumerate() {
118            if i > 0 {
119                tokens.append(proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone));
120            }
121
122            arg.to_tokens(tokens);
123        }
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` into [`Args`], panicking if it is not valid argument syntax.
    fn args_of(src: &str) -> Args {
        syn::parse_str::<Args>(src).unwrap()
    }

    mod parse {
        use super::*;

        #[test]
        fn empty() {
            let parsed = args_of("");
            assert_eq!(parsed.len(), 0);
            assert!(parsed.is_empty());
        }

        #[test]
        fn single_flag() {
            let parsed = args_of("skip");
            assert_eq!(parsed.len(), 1);
            assert!(parsed.has("skip"));
        }

        #[test]
        fn multiple() {
            let parsed = args_of("skip, rename = \"foo\", serde(flatten)");
            assert_eq!(parsed.len(), 3);
            for key in ["skip", "rename", "serde"] {
                assert!(parsed.has(key), "missing `{key}`");
            }
        }

        #[test]
        fn trailing_comma() {
            assert_eq!(args_of("skip, rename = \"foo\",").len(), 2);
        }

        #[test]
        fn lit() {
            let parsed = args_of("\"hello\"");
            assert_eq!(parsed.len(), 1);
            assert!(parsed[0].is_lit());
        }

        #[test]
        fn nested() {
            let parsed = args_of("outer(inner(deep = 1))");
            assert_eq!(parsed.len(), 1);

            let outer = parsed.get("outer").unwrap().as_args();
            assert!(outer.has("inner"));

            let inner = outer.get("inner").unwrap().as_args();
            assert!(inner.has("deep"));
        }
    }

    mod query {
        use super::*;

        #[test]
        fn has_existing() {
            let parsed = args_of("skip, rename = \"foo\"");
            assert!(parsed.has("skip") && parsed.has("rename"));
        }

        #[test]
        fn has_missing() {
            assert!(!args_of("skip").has("rename"));
        }

        #[test]
        fn get_existing() {
            let parsed = args_of("skip, rename = \"foo\"");
            let skip = parsed.get("skip").unwrap();
            let rename = parsed.get("rename").unwrap();
            assert!(skip.is_flag());
            assert!(rename.is_expr());
        }

        #[test]
        fn get_missing() {
            assert!(args_of("skip").get("rename").is_none());
        }

        #[test]
        fn index() {
            let parsed = args_of("skip, rename = \"foo\"");
            assert!(parsed[0].is_flag());
            assert!(parsed[1].is_expr());
        }

        #[test]
        fn iter() {
            let parsed = args_of("a, b, c");
            let mut names = Vec::new();
            for arg in parsed.iter() {
                if let Some(name) = arg.name() {
                    names.push(name.to_string());
                }
            }
            assert_eq!(names, ["a", "b", "c"]);
        }
    }

    mod merge {
        use super::*;

        #[test]
        fn no_overlap() {
            let merged = args_of("skip").merge(&args_of("rename = \"foo\""));
            assert_eq!(merged.len(), 2);
            assert!(merged.has("skip"));
            assert!(merged.has("rename"));
        }

        #[test]
        fn override_key() {
            let base = args_of("rename = \"foo\"");
            let overrides = args_of("rename = \"bar\"");
            assert_eq!(base.merge(&overrides).len(), 1);
        }

        #[test]
        fn extend_keeps_duplicates() {
            let mut combined = args_of("skip");
            combined.extend(args_of("skip, rename = \"foo\""));
            assert_eq!(combined.len(), 3);
        }

        #[test]
        fn empty_merge() {
            let merged = args_of("skip").merge(&args_of(""));
            assert_eq!(merged.len(), 1);
            assert!(merged.has("skip"));
        }
    }

    mod to_tokens {
        use super::*;

        #[test]
        fn round_trip() {
            let original = args_of("skip , rename = \"foo\" , serde (flatten)");
            let printed = original.to_token_stream().to_string();
            let reparsed = args_of(&printed);
            assert_eq!(reparsed.len(), 3);
            for key in ["skip", "rename", "serde"] {
                assert!(reparsed.has(key));
            }
        }
    }
}