traversable_derive/
lib.rs

1// Copyright 2025 FastLabs Developers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::collections::hash_map::Entry;
17use std::iter::IntoIterator;
18
19use proc_macro2::Span;
20use proc_macro2::TokenStream;
21use quote::ToTokens;
22use quote::quote;
23use syn::Attribute;
24use syn::Data;
25use syn::DataEnum;
26use syn::DataStruct;
27use syn::DeriveInput;
28use syn::Error;
29use syn::Expr;
30use syn::Field;
31use syn::Fields;
32use syn::Ident;
33use syn::Lit;
34use syn::LitStr;
35use syn::Member;
36use syn::Meta;
37use syn::MetaList;
38use syn::Path;
39use syn::Result;
40use syn::Token;
41use syn::Variant;
42use syn::parse_macro_input;
43use syn::parse_quote;
44use syn::punctuated::Punctuated;
45use syn::spanned::Spanned;
46use syn::token::Mut;
47
48#[proc_macro_derive(Traversable, attributes(traverse))]
49pub fn derive_traversable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
50    expand_with(input, |stream| impl_traversable(stream, false))
51}
52
53#[proc_macro_derive(TraversableMut, attributes(traverse))]
54pub fn derive_traversable_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
55    expand_with(input, |stream| impl_traversable(stream, true))
56}
57
58fn expand_with(
59    input: proc_macro::TokenStream,
60    handler: impl Fn(DeriveInput) -> Result<TokenStream>,
61) -> proc_macro::TokenStream {
62    let input = parse_macro_input!(input as DeriveInput);
63    handler(input)
64        .unwrap_or_else(|error| error.to_compile_error())
65        .into()
66}
67
68fn extract_meta(attrs: Vec<Attribute>, attr_name: &str) -> Result<Option<Meta>> {
69    let macro_attrs = attrs
70        .into_iter()
71        .filter(|attr| attr.path().is_ident(attr_name))
72        .collect::<Vec<Attribute>>();
73
74    if let Some(second) = macro_attrs.get(2) {
75        return Err(Error::new_spanned(second, "duplicate attribute"));
76    }
77
78    macro_attrs
79        .first()
80        .map(|attr| Ok(attr.meta.clone()))
81        .transpose()
82}
83
84#[derive(Default)]
85struct Params(HashMap<Path, Meta>);
86
87impl Params {
88    fn from_attrs(attrs: Vec<Attribute>, attr_name: &str) -> Result<Self> {
89        Ok(extract_meta(attrs, attr_name)?
90            .map(|meta| {
91                if let Meta::List(meta_list) = meta {
92                    Self::from_meta_list(meta_list)
93                } else {
94                    Err(Error::new_spanned(meta, "invalid attribute"))
95                }
96            })
97            .transpose()?
98            .unwrap_or_default())
99    }
100
101    fn from_meta_list(meta_list: MetaList) -> Result<Self> {
102        let mut params = HashMap::new();
103        let nested = meta_list.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
104        for meta in nested {
105            let path = meta.path();
106            let entry = params.entry(path.clone());
107            if matches!(entry, Entry::Occupied(_)) {
108                return Err(Error::new_spanned(path, "duplicate parameter"));
109            }
110            entry.or_insert(meta);
111        }
112        Ok(Self(params))
113    }
114
115    fn validate(&self, allowed_params: &[&str]) -> Result<()> {
116        for path in self.0.keys() {
117            if !allowed_params
118                .iter()
119                .any(|allowed_param| path.is_ident(allowed_param))
120            {
121                return Err(Error::new_spanned(
122                    path,
123                    format!(
124                        "unknown parameter, supported: {}",
125                        allowed_params.join(", ")
126                    ),
127                ));
128            }
129        }
130        Ok(())
131    }
132
133    fn param(&mut self, name: &str) -> Result<Option<Param>> {
134        self.0
135            .remove(&Ident::new(name, Span::call_site()).into())
136            .map(Param::from_meta)
137            .transpose()
138    }
139}
140
141impl Iterator for Params {
142    type Item = Result<Param>;
143    fn next(&mut self) -> Option<Self::Item> {
144        self.0
145            .keys()
146            .next()
147            .cloned()
148            .map(|path| Param::from_meta(self.0.remove(&path).unwrap()))
149    }
150}
151
152enum Param {
153    Unit(Span),
154    StringLiteral(Span, LitStr),
155    NestedParams(Span),
156}
157
158impl Param {
159    fn from_meta(meta: Meta) -> Result<Self> {
160        let span = meta.span();
161        match meta {
162            Meta::Path(_) => Ok(Param::Unit(span)),
163            Meta::List(_) => Ok(Param::NestedParams(span)),
164            Meta::NameValue(name_value) => {
165                if let Expr::Lit(expr_lit) = &name_value.value {
166                    if let Lit::Str(lit_str) = &expr_lit.lit {
167                        Ok(Param::StringLiteral(span, lit_str.clone()))
168                    } else {
169                        Err(Error::new_spanned(name_value, "invalid parameter"))
170                    }
171                } else {
172                    Err(Error::new_spanned(name_value, "invalid parameter"))
173                }
174            }
175        }
176    }
177
178    fn span(&self) -> Span {
179        match self {
180            Self::Unit(span) | Self::StringLiteral(span, _) | Self::NestedParams(span) => *span,
181        }
182    }
183
184    fn unit(self) -> Result<()> {
185        if let Self::Unit(_) = self {
186            Ok(())
187        } else {
188            Err(Error::new(self.span(), "invalid parameter"))
189        }
190    }
191
192    fn string_literal(self) -> Result<LitStr> {
193        if let Self::StringLiteral(_, lit_str) = self {
194            Ok(lit_str)
195        } else {
196            Err(Error::new(self.span(), "invalid parameter"))
197        }
198    }
199}
200
201#[inline(always)]
202fn resolve_crate_name() -> Path {
203    parse_quote!(::traversable)
204}
205
206fn impl_traversable(input: DeriveInput, mutable: bool) -> Result<TokenStream> {
207    let mut params = Params::from_attrs(input.attrs, "traverse")?;
208    params.validate(&["skip"])?;
209
210    let skip_visit_self = params
211        .param("skip")?
212        .map(Param::unit)
213        .transpose()?
214        .is_some();
215
216    let name = input.ident;
217    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
218
219    let visitor = Ident::new(
220        if mutable { "VisitorMut" } else { "Visitor" },
221        Span::call_site(),
222    );
223
224    let enter_method = Ident::new(
225        if mutable { "enter_mut" } else { "enter" },
226        Span::call_site(),
227    );
228
229    let leave_method = Ident::new(
230        if mutable { "leave_mut" } else { "leave" },
231        Span::call_site(),
232    );
233
234    let crate_name = resolve_crate_name();
235
236    let enter_self = if skip_visit_self {
237        None
238    } else {
239        Some(quote! {
240            #crate_name::#visitor::#enter_method(visitor, self)?;
241        })
242    };
243
244    let leave_self = if skip_visit_self {
245        None
246    } else {
247        Some(quote! {
248            #crate_name::#visitor::#leave_method(visitor, self)?;
249        })
250    };
251
252    let traverse_fields = match input.data {
253        Data::Struct(struct_) => traverse_struct(struct_, mutable),
254        Data::Enum(enum_) => traverse_enum(enum_, mutable),
255        Data::Union(union_) => {
256            return Err(Error::new_spanned(
257                union_.union_token,
258                "unions are not supported",
259            ));
260        }
261    }?;
262
263    let impl_trait = Ident::new(
264        if mutable {
265            "TraversableMut"
266        } else {
267            "Traversable"
268        },
269        Span::call_site(),
270    );
271
272    let method = Ident::new(
273        if mutable { "traverse_mut" } else { "traverse" },
274        Span::call_site(),
275    );
276
277    let mut_modifier = if mutable {
278        Some(Mut(Span::call_site()))
279    } else {
280        None
281    };
282
283    Ok(quote! {
284        impl #impl_generics #crate_name::#impl_trait for #name #ty_generics #where_clause {
285            fn #method<V: #crate_name::#visitor>(
286                & #mut_modifier self,
287                visitor: &mut V
288            ) -> ::core::ops::ControlFlow<V::Break> {
289                #enter_self
290                #traverse_fields
291                #leave_self
292                ::core::ops::ControlFlow::Continue(())
293            }
294        }
295    })
296}
297
298fn traverse_struct(s: DataStruct, mutable: bool) -> Result<TokenStream> {
299    s.fields
300        .into_iter()
301        .enumerate()
302        .map(|(index, field)| {
303            let member = field.ident.as_ref().map_or_else(
304                || Member::Unnamed(index.into()),
305                |ident| Member::Named(ident.clone()),
306            );
307            let mut_modifier = if mutable {
308                Some(Mut(Span::call_site()))
309            } else {
310                None
311            };
312            traverse_field(&quote! { & #mut_modifier self.#member }, field, mutable)
313        })
314        .collect()
315}
316
317fn traverse_enum(e: DataEnum, mutable: bool) -> Result<TokenStream> {
318    let variants = e
319        .variants
320        .into_iter()
321        .map(|x| traverse_variant(x, mutable))
322        .collect::<Result<TokenStream>>()?;
323    Ok(quote! {
324        match self {
325            #variants
326            _ => {}
327        }
328    })
329}
330
331fn traverse_variant(v: Variant, mutable: bool) -> Result<TokenStream> {
332    let mut params = Params::from_attrs(v.attrs, "traverse")?;
333    params.validate(&["skip"])?;
334    if params.param("skip")?.map(Param::unit).is_some() {
335        return Ok(TokenStream::new());
336    }
337    let name = v.ident;
338    let destructuring = destructure_fields(v.fields.clone())?;
339    let fields = v
340        .fields
341        .into_iter()
342        .enumerate()
343        .map(|(index, field)| {
344            traverse_field(
345                &field
346                    .ident
347                    .clone()
348                    .unwrap_or_else(|| Ident::new(&format!("i{}", index), Span::call_site()))
349                    .to_token_stream(),
350                field,
351                mutable,
352            )
353        })
354        .collect::<Result<TokenStream>>()?;
355    Ok(quote! {
356        Self::#name #destructuring => {
357            #fields
358        }
359    })
360}
361
362fn destructure_fields(fields: Fields) -> Result<TokenStream> {
363    Ok(match fields {
364        Fields::Named(fields) => {
365            let field_list = fields
366                .named
367                .into_iter()
368                .map(|field| {
369                    let mut params = Params::from_attrs(field.attrs, "traverse")?;
370                    let field_name = field.ident.unwrap();
371                    Ok(if params.param("skip")?.map(Param::unit).is_some() {
372                        quote! { #field_name: _ }
373                    } else {
374                        field_name.into_token_stream()
375                    })
376                })
377                .collect::<Result<Vec<TokenStream>>>()?;
378            quote! {
379                { #( #field_list ),* }
380            }
381        }
382        Fields::Unnamed(fields) => {
383            let field_list = fields
384                .unnamed
385                .into_iter()
386                .enumerate()
387                .map(|(index, field)| {
388                    let mut params = Params::from_attrs(field.attrs, "traverse")?;
389                    Ok(if params.param("skip")?.map(Param::unit).is_some() {
390                        quote! { _ }
391                    } else {
392                        Ident::new(&format!("i{index}",), Span::call_site()).into_token_stream()
393                    })
394                })
395                .collect::<Result<Vec<TokenStream>>>()?;
396            quote! {
397                ( #( #field_list ),* )
398            }
399        }
400        Fields::Unit => TokenStream::new(),
401    })
402}
403
404fn traverse_field(value: &TokenStream, field: Field, mutable: bool) -> Result<TokenStream> {
405    let mut params = Params::from_attrs(field.attrs, "traverse")?;
406    params.validate(&["skip", "with"])?;
407
408    if params.param("skip")?.map(Param::unit).is_some() {
409        return Ok(TokenStream::new());
410    }
411
412    let crate_name = resolve_crate_name();
413
414    match params.param("with")? {
415        None => Ok(if mutable {
416            quote! { #crate_name::TraversableMut::traverse_mut(#value, visitor)?; }
417        } else {
418            quote! { #crate_name::Traversable::traverse(#value, visitor)?; }
419        }),
420        Some(traverse_fn) => {
421            let traverse_fn = traverse_fn.string_literal()?.parse::<Path>()?;
422            Ok(quote! {
423                #traverse_fn(#value, visitor)?;
424            })
425        }
426    }
427}