1use std::collections::HashMap;
16use std::collections::hash_map::Entry;
17use std::iter::IntoIterator;
18
19use proc_macro2::Span;
20use proc_macro2::TokenStream;
21use quote::ToTokens;
22use quote::quote;
23use syn::Attribute;
24use syn::Data;
25use syn::DataEnum;
26use syn::DataStruct;
27use syn::DeriveInput;
28use syn::Error;
29use syn::Expr;
30use syn::Field;
31use syn::Fields;
32use syn::Ident;
33use syn::Lit;
34use syn::LitStr;
35use syn::Member;
36use syn::Meta;
37use syn::MetaList;
38use syn::Path;
39use syn::Result;
40use syn::Token;
41use syn::Variant;
42use syn::parse_macro_input;
43use syn::parse_quote;
44use syn::punctuated::Punctuated;
45use syn::spanned::Spanned;
46use syn::token::Mut;
47
48#[proc_macro_derive(Traversable, attributes(traverse))]
49pub fn derive_traversable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
50 expand_with(input, |stream| impl_traversable(stream, false))
51}
52
53#[proc_macro_derive(TraversableMut, attributes(traverse))]
54pub fn derive_traversable_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
55 expand_with(input, |stream| impl_traversable(stream, true))
56}
57
58fn expand_with(
59 input: proc_macro::TokenStream,
60 handler: impl Fn(DeriveInput) -> Result<TokenStream>,
61) -> proc_macro::TokenStream {
62 let input = parse_macro_input!(input as DeriveInput);
63 handler(input)
64 .unwrap_or_else(|error| error.to_compile_error())
65 .into()
66}
67
68fn extract_meta(attrs: Vec<Attribute>, attr_name: &str) -> Result<Option<Meta>> {
69 let macro_attrs = attrs
70 .into_iter()
71 .filter(|attr| attr.path().is_ident(attr_name))
72 .collect::<Vec<Attribute>>();
73
74 if let Some(second) = macro_attrs.get(2) {
75 return Err(Error::new_spanned(second, "duplicate attribute"));
76 }
77
78 macro_attrs
79 .first()
80 .map(|attr| Ok(attr.meta.clone()))
81 .transpose()
82}
83
84#[derive(Default)]
85struct Params(HashMap<Path, Meta>);
86
87impl Params {
88 fn from_attrs(attrs: Vec<Attribute>, attr_name: &str) -> Result<Self> {
89 Ok(extract_meta(attrs, attr_name)?
90 .map(|meta| {
91 if let Meta::List(meta_list) = meta {
92 Self::from_meta_list(meta_list)
93 } else {
94 Err(Error::new_spanned(meta, "invalid attribute"))
95 }
96 })
97 .transpose()?
98 .unwrap_or_default())
99 }
100
101 fn from_meta_list(meta_list: MetaList) -> Result<Self> {
102 let mut params = HashMap::new();
103 let nested = meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
104 for meta in nested {
105 let path = meta.path();
106 let entry = params.entry(path.clone());
107 if matches!(entry, Entry::Occupied(_)) {
108 return Err(Error::new_spanned(path, "duplicate parameter"));
109 }
110 entry.or_insert(meta);
111 }
112 Ok(Self(params))
113 }
114
115 fn validate(&self, allowed_params: &[&str]) -> Result<()> {
116 for path in self.0.keys() {
117 if !allowed_params
118 .iter()
119 .any(|allowed_param| path.is_ident(allowed_param))
120 {
121 return Err(Error::new_spanned(
122 path,
123 format!(
124 "unknown parameter, supported: {}",
125 allowed_params.join(", ")
126 ),
127 ));
128 }
129 }
130 Ok(())
131 }
132
133 fn param(&mut self, name: &str) -> Result<Option<Param>> {
134 self.0
135 .remove(&Ident::new(name, Span::call_site()).into())
136 .map(Param::from_meta)
137 .transpose()
138 }
139}
140
141impl Iterator for Params {
142 type Item = Result<Param>;
143 fn next(&mut self) -> Option<Self::Item> {
144 self.0
145 .keys()
146 .next()
147 .cloned()
148 .map(|path| Param::from_meta(self.0.remove(&path).unwrap()))
149 }
150}
151
152enum Param {
153 Unit(Span),
154 StringLiteral(Span, LitStr),
155 NestedParams(Span),
156}
157
158impl Param {
159 fn from_meta(meta: Meta) -> Result<Self> {
160 let span = meta.span();
161 match meta {
162 Meta::Path(_) => Ok(Param::Unit(span)),
163 Meta::List(_) => Ok(Param::NestedParams(span)),
164 Meta::NameValue(name_value) => {
165 if let Expr::Lit(expr_lit) = &name_value.value {
166 if let Lit::Str(lit_str) = &expr_lit.lit {
167 Ok(Param::StringLiteral(span, lit_str.clone()))
168 } else {
169 Err(Error::new_spanned(name_value, "invalid parameter"))
170 }
171 } else {
172 Err(Error::new_spanned(name_value, "invalid parameter"))
173 }
174 }
175 }
176 }
177
178 fn span(&self) -> Span {
179 match self {
180 Self::Unit(span) | Self::StringLiteral(span, _) | Self::NestedParams(span) => *span,
181 }
182 }
183
184 fn unit(self) -> Result<()> {
185 if let Self::Unit(_) = self {
186 Ok(())
187 } else {
188 Err(Error::new(self.span(), "invalid parameter"))
189 }
190 }
191
192 fn string_literal(self) -> Result<LitStr> {
193 if let Self::StringLiteral(_, lit_str) = self {
194 Ok(lit_str)
195 } else {
196 Err(Error::new(self.span(), "invalid parameter"))
197 }
198 }
199}
200
201#[inline(always)]
202fn resolve_crate_name() -> Path {
203 parse_quote!(::traversable)
204}
205
206fn impl_traversable(input: DeriveInput, mutable: bool) -> Result<TokenStream> {
207 let mut params = Params::from_attrs(input.attrs, "traverse")?;
208 params.validate(&["skip"])?;
209
210 let skip_visit_self = params
211 .param("skip")?
212 .map(Param::unit)
213 .transpose()?
214 .is_some();
215
216 let name = input.ident;
217 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
218
219 let visitor = Ident::new(
220 if mutable { "VisitorMut" } else { "Visitor" },
221 Span::call_site(),
222 );
223
224 let enter_method = Ident::new(
225 if mutable { "enter_mut" } else { "enter" },
226 Span::call_site(),
227 );
228
229 let leave_method = Ident::new(
230 if mutable { "leave_mut" } else { "leave" },
231 Span::call_site(),
232 );
233
234 let crate_name = resolve_crate_name();
235
236 let enter_self = if skip_visit_self {
237 None
238 } else {
239 Some(quote! {
240 #crate_name::#visitor::#enter_method(visitor, self)?;
241 })
242 };
243
244 let leave_self = if skip_visit_self {
245 None
246 } else {
247 Some(quote! {
248 #crate_name::#visitor::#leave_method(visitor, self)?;
249 })
250 };
251
252 let traverse_fields = match input.data {
253 Data::Struct(struct_) => traverse_struct(struct_, mutable),
254 Data::Enum(enum_) => traverse_enum(enum_, mutable),
255 Data::Union(union_) => {
256 return Err(Error::new_spanned(
257 union_.union_token,
258 "unions are not supported",
259 ));
260 }
261 }?;
262
263 let impl_trait = Ident::new(
264 if mutable {
265 "TraversableMut"
266 } else {
267 "Traversable"
268 },
269 Span::call_site(),
270 );
271
272 let method = Ident::new(
273 if mutable { "traverse_mut" } else { "traverse" },
274 Span::call_site(),
275 );
276
277 let mut_modifier = if mutable {
278 Some(Mut(Span::call_site()))
279 } else {
280 None
281 };
282
283 Ok(quote! {
284 impl #impl_generics #crate_name::#impl_trait for #name #ty_generics #where_clause {
285 fn #method<V: #crate_name::#visitor>(
286 & #mut_modifier self,
287 visitor: &mut V
288 ) -> ::core::ops::ControlFlow<V::Break> {
289 #enter_self
290 #traverse_fields
291 #leave_self
292 ::core::ops::ControlFlow::Continue(())
293 }
294 }
295 })
296}
297
298fn traverse_struct(s: DataStruct, mutable: bool) -> Result<TokenStream> {
299 s.fields
300 .into_iter()
301 .enumerate()
302 .map(|(index, field)| {
303 let member = field.ident.as_ref().map_or_else(
304 || Member::Unnamed(index.into()),
305 |ident| Member::Named(ident.clone()),
306 );
307 let mut_modifier = if mutable {
308 Some(Mut(Span::call_site()))
309 } else {
310 None
311 };
312 traverse_field("e! { & #mut_modifier self.#member }, field, mutable)
313 })
314 .collect()
315}
316
317fn traverse_enum(e: DataEnum, mutable: bool) -> Result<TokenStream> {
318 let variants = e
319 .variants
320 .into_iter()
321 .map(|x| traverse_variant(x, mutable))
322 .collect::<Result<TokenStream>>()?;
323 Ok(quote! {
324 match self {
325 #variants
326 _ => {}
327 }
328 })
329}
330
331fn traverse_variant(v: Variant, mutable: bool) -> Result<TokenStream> {
332 let mut params = Params::from_attrs(v.attrs, "traverse")?;
333 params.validate(&["skip"])?;
334 if params.param("skip")?.map(Param::unit).is_some() {
335 return Ok(TokenStream::new());
336 }
337 let name = v.ident;
338 let destructuring = destructure_fields(v.fields.clone())?;
339 let fields = v
340 .fields
341 .into_iter()
342 .enumerate()
343 .map(|(index, field)| {
344 traverse_field(
345 &field
346 .ident
347 .clone()
348 .unwrap_or_else(|| Ident::new(&format!("i{}", index), Span::call_site()))
349 .to_token_stream(),
350 field,
351 mutable,
352 )
353 })
354 .collect::<Result<TokenStream>>()?;
355 Ok(quote! {
356 Self::#name #destructuring => {
357 #fields
358 }
359 })
360}
361
362fn destructure_fields(fields: Fields) -> Result<TokenStream> {
363 Ok(match fields {
364 Fields::Named(fields) => {
365 let field_list = fields
366 .named
367 .into_iter()
368 .map(|field| {
369 let mut params = Params::from_attrs(field.attrs, "traverse")?;
370 let field_name = field.ident.unwrap();
371 Ok(if params.param("skip")?.map(Param::unit).is_some() {
372 quote! { #field_name: _ }
373 } else {
374 field_name.into_token_stream()
375 })
376 })
377 .collect::<Result<Vec<TokenStream>>>()?;
378 quote! {
379 { #( #field_list ),* }
380 }
381 }
382 Fields::Unnamed(fields) => {
383 let field_list = fields
384 .unnamed
385 .into_iter()
386 .enumerate()
387 .map(|(index, field)| {
388 let mut params = Params::from_attrs(field.attrs, "traverse")?;
389 Ok(if params.param("skip")?.map(Param::unit).is_some() {
390 quote! { _ }
391 } else {
392 Ident::new(&format!("i{index}",), Span::call_site()).into_token_stream()
393 })
394 })
395 .collect::<Result<Vec<TokenStream>>>()?;
396 quote! {
397 ( #( #field_list ),* )
398 }
399 }
400 Fields::Unit => TokenStream::new(),
401 })
402}
403
404fn traverse_field(value: &TokenStream, field: Field, mutable: bool) -> Result<TokenStream> {
405 let mut params = Params::from_attrs(field.attrs, "traverse")?;
406 params.validate(&["skip", "with"])?;
407
408 if params.param("skip")?.map(Param::unit).is_some() {
409 return Ok(TokenStream::new());
410 }
411
412 let crate_name = resolve_crate_name();
413
414 match params.param("with")? {
415 None => Ok(if mutable {
416 quote! { #crate_name::TraversableMut::traverse_mut(#value, visitor)?; }
417 } else {
418 quote! { #crate_name::Traversable::traverse(#value, visitor)?; }
419 }),
420 Some(traverse_fn) => {
421 let traverse_fn = traverse_fn.string_literal()?.parse::<Path>()?;
422 Ok(quote! {
423 #traverse_fn(#value, visitor)?;
424 })
425 }
426 }
427}