1use std::ops::Index;
2
3use quote::ToTokens;
4use quote::TokenStreamExt;
5
6use syn::Token;
7use syn::parse::Parse;
8use syn::parse::ParseStream;
9
10use super::Arg;
11
12#[derive(Clone, Default)]
13pub struct Args(Vec<Arg>);
14
15impl Args {
16 pub fn new() -> Self {
17 Self(Vec::new())
18 }
19
20 pub fn has(&self, name: &str) -> bool {
21 self.0
22 .iter()
23 .any(|arg| arg.name().is_some_and(|n| n == name))
24 }
25
26 pub fn get(&self, name: &str) -> Option<&Arg> {
27 self.0
28 .iter()
29 .find(|arg| arg.name().is_some_and(|n| n == name))
30 }
31
32 pub fn get_index(&self, index: usize) -> Option<&Arg> {
33 self.0.get(index)
34 }
35
36 pub fn iter(&self) -> impl Iterator<Item = &Arg> {
37 self.0.iter()
38 }
39
40 pub fn len(&self) -> usize {
41 self.0.len()
42 }
43
44 pub fn is_empty(&self) -> bool {
45 self.0.is_empty()
46 }
47
48 pub fn extend(&mut self, other: Args) {
49 self.0.extend(other.0);
50 }
51
52 pub fn merge(&self, other: &Args) -> Args {
53 let mut result: Vec<Arg> = Vec::new();
54
55 for arg in &self.0 {
56 if let Some(name) = arg.name()
57 && other.has(&name.to_string())
58 {
59 continue;
60 }
61
62 result.push(arg.clone());
63 }
64
65 for arg in &other.0 {
66 result.push(arg.clone());
67 }
68
69 Args(result)
70 }
71}
72
73impl Index<usize> for Args {
74 type Output = Arg;
75
76 fn index(&self, index: usize) -> &Self::Output {
77 &self.0[index]
78 }
79}
80
81impl IntoIterator for Args {
82 type Item = Arg;
83 type IntoIter = std::vec::IntoIter<Arg>;
84
85 fn into_iter(self) -> Self::IntoIter {
86 self.0.into_iter()
87 }
88}
89
90impl<'a> IntoIterator for &'a Args {
91 type Item = &'a Arg;
92 type IntoIter = std::slice::Iter<'a, Arg>;
93
94 fn into_iter(self) -> Self::IntoIter {
95 self.0.iter()
96 }
97}
98
99impl Parse for Args {
100 fn parse(input: ParseStream) -> syn::Result<Self> {
101 let mut args = Vec::new();
102
103 while !input.is_empty() {
104 args.push(input.parse::<Arg>()?);
105
106 if input.peek(Token![,]) {
107 input.parse::<Token![,]>()?;
108 }
109 }
110
111 Ok(Self(args))
112 }
113}
114
115impl ToTokens for Args {
116 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
117 for (i, arg) in self.0.iter().enumerate() {
118 if i > 0 {
119 tokens.append(proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone));
120 }
121
122 arg.to_tokens(tokens);
123 }
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` into an [`Args`] list, panicking if the input is rejected.
    #[track_caller]
    fn parse_args(src: &str) -> Args {
        syn::parse_str(src).unwrap()
    }

    /// Parsing of various argument shapes.
    mod parse {
        use super::*;

        #[test]
        fn empty() {
            let parsed = parse_args("");
            assert!(parsed.is_empty());
            assert_eq!(parsed.len(), 0);
        }

        #[test]
        fn single_flag() {
            let parsed = parse_args("skip");
            assert_eq!(parsed.len(), 1);
            assert!(parsed.has("skip"));
        }

        #[test]
        fn multiple() {
            let parsed = parse_args(r#"skip, rename = "foo", serde(flatten)"#);
            assert_eq!(parsed.len(), 3);
            for expected in ["skip", "rename", "serde"] {
                assert!(parsed.has(expected));
            }
        }

        #[test]
        fn trailing_comma() {
            let parsed = parse_args(r#"skip, rename = "foo","#);
            assert_eq!(parsed.len(), 2);
        }

        #[test]
        fn lit() {
            let parsed = parse_args(r#""hello""#);
            assert_eq!(parsed.len(), 1);
            assert!(parsed[0].is_lit());
        }

        #[test]
        fn nested() {
            let parsed = parse_args("outer(inner(deep = 1))");
            assert_eq!(parsed.len(), 1);

            // Walk down one nesting level at a time.
            let outer = parsed.get("outer").unwrap().as_args();
            assert!(outer.has("inner"));

            let inner = outer.get("inner").unwrap().as_args();
            assert!(inner.has("deep"));
        }
    }

    /// Lookup, indexing and iteration.
    mod query {
        use super::*;

        #[test]
        fn has_existing() {
            let parsed = parse_args(r#"skip, rename = "foo""#);
            assert!(parsed.has("skip"));
            assert!(parsed.has("rename"));
        }

        #[test]
        fn has_missing() {
            let parsed = parse_args("skip");
            assert!(!parsed.has("rename"));
        }

        #[test]
        fn get_existing() {
            let parsed = parse_args(r#"skip, rename = "foo""#);
            assert!(parsed.get("skip").unwrap().is_flag());
            assert!(parsed.get("rename").unwrap().is_expr());
        }

        #[test]
        fn get_missing() {
            let parsed = parse_args("skip");
            assert!(parsed.get("rename").is_none());
        }

        #[test]
        fn index() {
            let parsed = parse_args(r#"skip, rename = "foo""#);
            assert!(parsed[0].is_flag());
            assert!(parsed[1].is_expr());
        }

        #[test]
        fn iter() {
            let parsed = parse_args("a, b, c");
            let names: Vec<String> = parsed
                .iter()
                .filter_map(|arg| arg.name().map(|name| name.to_string()))
                .collect();
            assert_eq!(names, vec!["a", "b", "c"]);
        }
    }

    /// Combining two lists with `merge` and `extend`.
    mod merge {
        use super::*;

        #[test]
        fn no_overlap() {
            let base = parse_args("skip");
            let overlay = parse_args(r#"rename = "foo""#);
            let merged = base.merge(&overlay);
            assert_eq!(merged.len(), 2);
            assert!(merged.has("skip"));
            assert!(merged.has("rename"));
        }

        #[test]
        fn override_key() {
            let base = parse_args(r#"rename = "foo""#);
            let overlay = parse_args(r#"rename = "bar""#);
            let merged = base.merge(&overlay);
            assert_eq!(merged.len(), 1);
        }

        #[test]
        fn extend_keeps_duplicates() {
            let mut base = parse_args("skip");
            let extra = parse_args(r#"skip, rename = "foo""#);
            base.extend(extra);
            assert_eq!(base.len(), 3);
        }

        #[test]
        fn empty_merge() {
            let base = parse_args("skip");
            let overlay = parse_args("");
            let merged = base.merge(&overlay);
            assert_eq!(merged.len(), 1);
            assert!(merged.has("skip"));
        }
    }

    /// Token emission.
    mod to_tokens {
        use super::*;

        #[test]
        fn round_trip() {
            let source = r#"skip , rename = "foo" , serde (flatten)"#;
            let parsed = parse_args(source);

            // Emit tokens, render them to text, and parse that text again.
            let rendered = parsed.to_token_stream().to_string();
            let reparsed = parse_args(&rendered);

            assert_eq!(reparsed.len(), 3);
            assert!(reparsed.has("skip"));
            assert!(reparsed.has("rename"));
            assert!(reparsed.has("serde"));
        }
    }
}