visit_rs_derive/
lib.rs

1use std::collections::HashSet;
2
3use proc_macro2::{Span, TokenStream};
4use quote::quote;
5use syn::{
6    parse_quote, DataStruct, DeriveInput, Fields, Ident, Lit, Meta, Path, WhereClause,
7    WherePredicate,
8};
9
10fn get_rename_attribute(ast: &DeriveInput) -> Option<String> {
11    for attr in &ast.attrs {
12        // Check for #[visit(rename = "...")]
13        if attr.path().is_ident("visit") {
14            if let Ok(meta_list) = attr.meta.require_list() {
15                if let Ok(Meta::NameValue(nv)) = syn::parse2::<Meta>(meta_list.tokens.clone()) {
16                    if nv.path.is_ident("rename") {
17                        if let syn::Expr::Lit(lit) = &nv.value {
18                            if let Lit::Str(s) = &lit.lit {
19                                return Some(s.value());
20                            }
21                        }
22                    }
23                }
24            }
25        }
26        // Check for #[serde(rename = "...")]
27        if attr.path().is_ident("serde") {
28            if let Ok(meta_list) = attr.meta.require_list() {
29                if let Ok(Meta::NameValue(nv)) = syn::parse2::<Meta>(meta_list.tokens.clone()) {
30                    if nv.path.is_ident("rename") {
31                        if let syn::Expr::Lit(lit) = &nv.value {
32                            if let Lit::Str(s) = &lit.lit {
33                                return Some(s.value());
34                            }
35                        }
36                    }
37                }
38            }
39        }
40    }
41    None
42}
43
44fn make_impl(
45    input: &DeriveInput,
46    fields: &Fields,
47    trait_path_fields: &Path,
48    trait_path: &Path,
49    named: Option<&Path>,
50    sync: bool,
51    is_static: bool,
52) -> TokenStream {
53    let ident = &input.ident;
54
55    let (_, ty_generics, _) = &input.generics.split_for_impl();
56
57    let mut generics = input.generics.clone();
58
59    generics.params.push(syn::parse_quote! { __visit_rs__V });
60
61    let predicates = &mut generics
62        .where_clause
63        .get_or_insert(WhereClause {
64            predicates: Default::default(),
65            where_token: Default::default(),
66        })
67        .predicates;
68
69    predicates.push(syn::parse_quote! { __visit_rs__V: visit_rs::Visitor });
70    if sync {
71        predicates.extend(fields.iter().map(|f| &f.ty).map(|t| -> WherePredicate {
72            parse_quote! { #t: Sync }
73        }));
74    }
75
76    let mut ty_set = HashSet::new();
77    for (_, field) in field_iter(fields) {
78        let ty = &field.ty;
79        if !ty_set.insert(ty) {
80            continue;
81        }
82        if let Some(named) = named {
83            if is_static {
84                predicates.push(
85                    syn::parse_quote! { for<'__visit_rs__named> #named <'__visit_rs__named, visit_rs::Static<#ty>>: #trait_path<__visit_rs__V> },
86                );
87            } else {
88                predicates.push(syn::parse_quote! { for<'__visit_rs__named> #named <'__visit_rs__named, #ty>: #trait_path<__visit_rs__V> });
89            }
90        } else {
91            if is_static {
92                predicates
93                    .push(syn::parse_quote! { visit_rs::Static<#ty>: #trait_path<__visit_rs__V> });
94            } else {
95                predicates.push(syn::parse_quote! { #ty: #trait_path<__visit_rs__V> });
96            }
97        }
98    }
99
100    let (impl_generics, _, where_clause) = generics.split_for_impl();
101
102    quote! {
103        impl #impl_generics #trait_path_fields<__visit_rs__V> for #ident #ty_generics #where_clause
104    }
105}
106
107fn field_iter(fields: &Fields) -> impl Iterator<Item = (usize, &syn::Field)> {
108    fields.iter().enumerate().filter(|(_, field)| {
109        !field.attrs.iter().any(|attr| {
110            attr.path().is_ident("visit")
111                && attr.parse_args::<Ident>().map_or(false, |id| id == "skip")
112        })
113    })
114}
115
116fn field_idx_iter(fields: &Fields) -> impl Iterator<Item = TokenStream> {
117    field_iter(fields).map(|(index, field)| {
118        let field_name = &field.ident;
119        if let Some(name) = field_name {
120            quote! { #name }
121        } else {
122            let index = syn::Index::from(index);
123            quote! { #index }
124        }
125    })
126}
127
128fn field_name_idx_iter(fields: &syn::Fields) -> impl Iterator<Item = (TokenStream, TokenStream)> {
129    field_iter(fields).map(|(index, field)| {
130        let field_name = &field.ident;
131        let idx = if let Some(name) = field_name {
132            quote! { #name }
133        } else {
134            let index = syn::Index::from(index);
135            quote! { #index }
136        };
137        let name = if let Some(name) = field_name {
138            quote! { Some(stringify!(#name)) }
139        } else {
140            quote! { None }
141        };
142        (name, idx)
143    })
144}
145
146#[proc_macro_derive(VisitFields, attributes(visit))]
147pub fn derive_visit_fields_(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
148    let ast: DeriveInput = syn::parse(input).unwrap();
149
150    let syn::Data::Struct(data) = &ast.data else {
151        let span = match &ast.data {
152            syn::Data::Enum(data) => data.enum_token.span,
153            syn::Data::Union(data) => data.union_token.span,
154            _ => Span::call_site(),
155        };
156        return syn::Error::new(span, "VisitFields can only be derived for structs")
157            .to_compile_error()
158            .into();
159    };
160
161    let all_impls = match (|| {
162        Ok::<_, syn::Error>([
163            derive_struct_info(&ast, data)?,
164            derive_visit_fields(&ast, data)?,
165            derive_visit_fields_covered(&ast, data)?,
166            derive_visit_fields_async(&ast, data)?,
167            derive_visit_fields_covered_async(&ast, data)?,
168            derive_visit_fields_named(&ast, data)?,
169            derive_visit_fields_named_async(&ast, data)?,
170            derive_visit_fields_static(&ast, data)?,
171            derive_visit_fields_static_async(&ast, data)?,
172            derive_visit_fields_static_named(&ast, data)?,
173            derive_visit_fields_static_named_async(&ast, data)?,
174        ])
175    })() {
176        Ok(a) => a,
177        Err(e) => return e.to_compile_error().into(),
178    };
179
180    // panic!(
181    //     "{}",
182    proc_macro::TokenStream::from(quote! {
183        #(#all_impls)*
184    })
185    // )
186}
187
188fn derive_struct_info(ast: &DeriveInput, data: &DataStruct) -> Result<TokenStream, syn::Error> {
189    let ident = &ast.ident;
190    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
191
192    let named_fields = matches!(data.fields, Fields::Named(_));
193    let field_count = field_iter(&data.fields).count();
194
195    let name = get_rename_attribute(ast).unwrap_or_else(|| ident.to_string());
196
197    Ok(quote! {
198        impl #impl_generics visit_rs::StructInfo for #ident #ty_generics #where_clause {
199            const DATA: visit_rs::StructInfoData = visit_rs::StructInfoData {
200                name: #name,
201                named_fields: #named_fields,
202                field_count: #field_count,
203            };
204        }
205    })
206}
207
208fn derive_visit_fields(ast: &DeriveInput, data: &DataStruct) -> Result<TokenStream, syn::Error> {
209    let impl_t = make_impl(
210        &ast,
211        &data.fields,
212        &syn::parse_quote! { visit_rs::VisitFields },
213        &syn::parse_quote! { visit_rs::Visit },
214        None,
215        false,
216        false,
217    );
218
219    let visit_fields_impl = field_idx_iter(&data.fields).enumerate().map(|(num, idx)| {
220        quote! {
221            #num => {
222                pos += 1;
223                Some(visit_rs::Visit::visit(&self.#idx, visitor))
224            }
225        }
226    });
227
228    Ok(quote! {
229        #impl_t {
230            fn visit_fields<'__visit_rs__a>(
231                &'__visit_rs__a self,
232                visitor: &'__visit_rs__a mut __visit_rs__V
233            ) -> impl Iterator<Item = <__visit_rs__V as visit_rs::Visitor>::Result> {
234                std::iter::from_fn({
235                    let mut pos = 0;
236                    move || match pos {
237                        #(#visit_fields_impl)*
238                        _ => None,
239                    }
240                })
241            }
242        }
243    })
244}
245
246fn derive_visit_fields_covered(
247    ast: &DeriveInput,
248    data: &DataStruct,
249) -> Result<TokenStream, syn::Error> {
250    let impl_t = make_impl(
251        &ast,
252        &data.fields,
253        &syn::parse_quote! { visit_rs::VisitFieldsCovered },
254        &syn::parse_quote! { visit_rs::Visit },
255        Some(&syn::parse_quote! { visit_rs::Covered }),
256        false,
257        false,
258    );
259
260    let visit_fields_impl = field_idx_iter(&data.fields).enumerate().map(|(num, idx)| {
261        quote! {
262            #num => {
263                pos += 1;
264                Some(visit_rs::Visit::visit(&visit_rs::Covered(&self.#idx), visitor))
265            }
266        }
267    });
268
269    Ok(quote! {
270        #impl_t {
271            fn visit_fields_covered<'__visit_rs__a>(
272                &'__visit_rs__a self,
273                visitor: &'__visit_rs__a mut __visit_rs__V
274            ) -> impl Iterator<Item = <__visit_rs__V as visit_rs::Visitor>::Result> {
275                std::iter::from_fn({
276                    let mut pos = 0;
277                    move || match pos {
278                        #(#visit_fields_impl)*
279                        _ => None,
280                    }
281                })
282            }
283        }
284    })
285}
286
287fn derive_visit_fields_async(
288    ast: &DeriveInput,
289    data: &DataStruct,
290) -> Result<TokenStream, syn::Error> {
291    let impl_t = make_impl(
292        &ast,
293        &data.fields,
294        &syn::parse_quote! { visit_rs::VisitFieldsAsync },
295        &syn::parse_quote! { visit_rs::VisitAsync },
296        None,
297        true,
298        false,
299    );
300
301    let visit_fields_impl = field_idx_iter(&data.fields).map(|idx| {
302        quote! {
303            yield visit_rs::VisitAsync::visit_async(&self.#idx, visitor).await;
304        }
305    });
306
307    Ok(quote! {
308        #impl_t {
309            fn visit_fields_async<'__visit_rs__a>(
310                &'__visit_rs__a self,
311                visitor: &'__visit_rs__a mut __visit_rs__V,
312            ) -> impl visit_rs::lib::futures::Stream<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + Send + '__visit_rs__a
313            where
314                __visit_rs__V: Send,
315                <__visit_rs__V as visit_rs::Visitor>::Result: Send,
316            {
317                visit_rs::lib::async_stream::stream! {
318                    #(#visit_fields_impl)*
319                    #[allow(unreachable_code)]
320                    if false {
321                        yield unreachable!() as <__visit_rs__V as visit_rs::Visitor>::Result
322                    }
323                }
324            }
325        }
326    })
327}
328
329fn derive_visit_fields_covered_async(
330    ast: &DeriveInput,
331    data: &DataStruct,
332) -> Result<TokenStream, syn::Error> {
333    let impl_t = make_impl(
334        &ast,
335        &data.fields,
336        &syn::parse_quote! { visit_rs::VisitFieldsCoveredAsync },
337        &syn::parse_quote! { visit_rs::VisitAsync },
338        Some(&syn::parse_quote! { visit_rs::Covered }),
339        true,
340        false,
341    );
342
343    let visit_fields_impl = field_idx_iter(&data.fields).map(|idx| {
344        quote! {
345            yield visit_rs::VisitAsync::visit_async(&visit_rs::Covered(&self.#idx), visitor).await;
346        }
347    });
348
349    Ok(quote! {
350        #impl_t {
351            fn visit_fields_covered_async<'__visit_rs__a>(
352                &'__visit_rs__a self,
353                visitor: &'__visit_rs__a mut __visit_rs__V,
354            ) -> impl visit_rs::lib::futures::Stream<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + Send + '__visit_rs__a
355            where
356                __visit_rs__V: Send,
357                <__visit_rs__V as visit_rs::Visitor>::Result: Send,
358            {
359                visit_rs::lib::async_stream::stream! {
360                    #(#visit_fields_impl)*
361                    #[allow(unreachable_code)]
362                    if false {
363                        yield unreachable!() as <__visit_rs__V as visit_rs::Visitor>::Result
364                    }
365                }
366            }
367        }
368    })
369}
370
371fn derive_visit_fields_named(
372    ast: &DeriveInput,
373    data: &DataStruct,
374) -> Result<TokenStream, syn::Error> {
375    let impl_t = make_impl(
376        &ast,
377        &data.fields,
378        &syn::parse_quote! { visit_rs::VisitFieldsNamed },
379        &syn::parse_quote! { visit_rs::Visit },
380        Some(&syn::parse_quote! { visit_rs::Named }),
381        false,
382        false,
383    );
384
385    let visit_fields_named_impl =
386        field_name_idx_iter(&data.fields)
387            .enumerate()
388            .map(|(num, (name, idx))| {
389                quote! {
390                    #num => {
391                        pos += 1;
392                        Some(visit_rs::Visit::visit(&visit_rs::Named {
393                            name: #name,
394                            value: &self.#idx,
395                        }, visitor))
396                    }
397                }
398            });
399
400    Ok(quote! {
401        #impl_t {
402            fn visit_fields_named<'__visit_rs__a>(
403                &'__visit_rs__a self,
404                visitor: &'__visit_rs__a mut __visit_rs__V
405            ) -> impl Iterator<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + '__visit_rs__a {
406                std::iter::from_fn({
407                    let mut pos = 0;
408                    move || match pos {
409                        #(#visit_fields_named_impl)*
410                        _ => None,
411                    }
412                })
413            }
414        }
415    })
416}
417
418fn derive_visit_fields_named_async(
419    ast: &DeriveInput,
420    data: &DataStruct,
421) -> Result<TokenStream, syn::Error> {
422    let impl_t = make_impl(
423        &ast,
424        &data.fields,
425        &syn::parse_quote! { visit_rs::VisitFieldsNamedAsync },
426        &syn::parse_quote! { visit_rs::VisitAsync },
427        Some(&syn::parse_quote! { visit_rs::Named }),
428        true,
429        false,
430    );
431
432    let visit_fields_named_impl = field_name_idx_iter(&data.fields).map(|(name, idx)| {
433        quote! {
434            yield visit_rs::VisitAsync::visit_async(&visit_rs::Named {
435                name: #name,
436                value: &self.#idx,
437            }, visitor).await;
438        }
439    });
440
441    Ok(quote! {
442        #impl_t {
443            fn visit_fields_named_async<'__visit_rs__a>(
444                &'__visit_rs__a self,
445                visitor: &'__visit_rs__a mut __visit_rs__V,
446            ) -> impl visit_rs::lib::futures::Stream<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + Send + '__visit_rs__a
447            where
448                __visit_rs__V: Send,
449                <__visit_rs__V as visit_rs::Visitor>::Result: Send,
450            {
451                visit_rs::lib::async_stream::stream! {
452                    #(#visit_fields_named_impl)*
453                    #[allow(unreachable_code)]
454                    if false {
455                        yield unreachable!() as <__visit_rs__V as visit_rs::Visitor>::Result
456                    }
457                }
458            }
459        }
460    })
461}
462
463fn derive_visit_fields_static(
464    ast: &DeriveInput,
465    data: &DataStruct,
466) -> Result<TokenStream, syn::Error> {
467    let impl_t = make_impl(
468        &ast,
469        &data.fields,
470        &syn::parse_quote! { visit_rs::VisitFieldsStatic },
471        &syn::parse_quote! { visit_rs::Visit },
472        None,
473        false,
474        true,
475    );
476
477    let field_types: Vec<_> = field_iter(&data.fields)
478        .map(|(_, field)| &field.ty)
479        .collect();
480    let visit_fields_impl = field_types.iter().enumerate().map(|(num, ty)| {
481        quote! {
482            #num => {
483                pos += 1;
484                Some(visit_rs::Visit::visit(&visit_rs::Static::<#ty>::new(), visitor))
485            }
486        }
487    });
488
489    Ok(quote! {
490        #impl_t {
491            fn visit_fields_static<'__visit_rs__a>(
492                visitor: &'__visit_rs__a mut __visit_rs__V
493            ) -> impl Iterator<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + '__visit_rs__a {
494                std::iter::from_fn({
495                    let mut pos = 0;
496                    move || match pos {
497                        #(#visit_fields_impl)*
498                        _ => None,
499                    }
500                })
501            }
502        }
503    })
504}
505
506fn derive_visit_fields_static_async(
507    ast: &DeriveInput,
508    data: &DataStruct,
509) -> Result<TokenStream, syn::Error> {
510    let impl_t = make_impl(
511        &ast,
512        &data.fields,
513        &syn::parse_quote! { visit_rs::VisitFieldsStaticAsync },
514        &syn::parse_quote! { visit_rs::VisitAsync },
515        None,
516        true,
517        true,
518    );
519
520    let field_types: Vec<_> = field_iter(&data.fields)
521        .map(|(_, field)| &field.ty)
522        .collect();
523    let visit_fields_impl = field_types.iter().map(|ty| {
524        quote! {
525            yield visit_rs::VisitAsync::visit_async(&visit_rs::Static::<#ty>::new(), visitor).await;
526        }
527    });
528
529    Ok(quote! {
530        #impl_t {
531            fn visit_fields_static_async<'__visit_rs__a>(
532                visitor: &'__visit_rs__a mut __visit_rs__V,
533            ) -> impl visit_rs::lib::futures::Stream<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + Send + '__visit_rs__a
534            where
535                __visit_rs__V: Send,
536                <__visit_rs__V as visit_rs::Visitor>::Result: Send,
537            {
538                visit_rs::lib::async_stream::stream! {
539                    #(#visit_fields_impl)*
540                    #[allow(unreachable_code)]
541                    if false {
542                        yield unreachable!() as <__visit_rs__V as visit_rs::Visitor>::Result
543                    }
544                }
545            }
546        }
547    })
548}
549
550fn derive_visit_fields_static_named(
551    ast: &DeriveInput,
552    data: &DataStruct,
553) -> Result<TokenStream, syn::Error> {
554    let impl_t = make_impl(
555        &ast,
556        &data.fields,
557        &syn::parse_quote! { visit_rs::VisitFieldsStaticNamed },
558        &syn::parse_quote! { visit_rs::Visit },
559        Some(&syn::parse_quote! { visit_rs::Named }),
560        false,
561        true,
562    );
563
564    let field_name_type_iter = field_iter(&data.fields).map(|(_, field)| {
565        let field_name = &field.ident;
566        let ty = &field.ty;
567        let name = if let Some(name) = field_name {
568            quote! { Some(stringify!(#name)) }
569        } else {
570            quote! { None }
571        };
572        (name, ty)
573    });
574
575    let visit_fields_named_impl = field_name_type_iter.enumerate().map(|(num, (name, ty))| {
576        quote! {
577            #num => {
578                pos += 1;
579                {
580                    static __VISIT_RS_STATIC: visit_rs::Static<()> = visit_rs::Static::new();
581                    let named = visit_rs::Named {
582                        name: #name,
583                        value: unsafe {
584                            // SAFETY: Static<T> is zero-sized and contains only PhantomData,
585                            // so transmuting from &Static<()> to &Static<#ty> is safe
586                            &*(&__VISIT_RS_STATIC as *const visit_rs::Static<()> as *const visit_rs::Static<#ty>)
587                        },
588                    };
589                    Some(visit_rs::Visit::visit(&named, visitor))
590                }
591            }
592        }
593    });
594
595    Ok(quote! {
596        #impl_t {
597            fn visit_fields_static_named<'__visit_rs__a>(
598                visitor: &'__visit_rs__a mut __visit_rs__V
599            ) -> impl Iterator<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + '__visit_rs__a {
600                std::iter::from_fn({
601                    let mut pos = 0;
602                    move || match pos {
603                        #(#visit_fields_named_impl)*
604                        _ => None,
605                    }
606                })
607            }
608        }
609    })
610}
611
612fn derive_visit_fields_static_named_async(
613    ast: &DeriveInput,
614    data: &DataStruct,
615) -> Result<TokenStream, syn::Error> {
616    let impl_t = make_impl(
617        &ast,
618        &data.fields,
619        &syn::parse_quote! { visit_rs::VisitFieldsStaticNamedAsync },
620        &syn::parse_quote! { visit_rs::VisitAsync },
621        Some(&syn::parse_quote! { visit_rs::Named }),
622        true,
623        true,
624    );
625
626    let field_name_type_iter = field_iter(&data.fields).map(|(_, field)| {
627        let field_name = &field.ident;
628        let ty = &field.ty;
629        let name = if let Some(name) = field_name {
630            quote! { Some(stringify!(#name)) }
631        } else {
632            quote! { None }
633        };
634        (name, ty)
635    });
636
637    let visit_fields_named_impl = field_name_type_iter.map(|(name, ty)| {
638        quote! {
639            {
640                static __VISIT_RS_STATIC: visit_rs::Static<()> = visit_rs::Static::new();
641                let named = visit_rs::Named {
642                    name: #name,
643                    value: unsafe {
644                        // SAFETY: Static<T> is zero-sized and contains only PhantomData,
645                        // so transmuting from &Static<()> to &Static<#ty> is safe
646                        &*(&__VISIT_RS_STATIC as *const visit_rs::Static<()> as *const visit_rs::Static<#ty>)
647                    },
648                };
649                yield visit_rs::VisitAsync::visit_async(&named, visitor).await;
650            }
651        }
652    });
653
654    Ok(quote! {
655        #impl_t {
656            fn visit_fields_static_named_async<'__visit_rs__a>(
657                visitor: &'__visit_rs__a mut __visit_rs__V,
658            ) -> impl visit_rs::lib::futures::Stream<Item = <__visit_rs__V as visit_rs::Visitor>::Result> + Send + '__visit_rs__a
659            where
660                __visit_rs__V: Send,
661                <__visit_rs__V as visit_rs::Visitor>::Result: Send,
662            {
663                visit_rs::lib::async_stream::stream! {
664                    #(#visit_fields_named_impl)*
665                    #[allow(unreachable_code)]
666                    if false {
667                        yield unreachable!() as <__visit_rs__V as visit_rs::Visitor>::Result
668                    }
669                }
670            }
671        }
672    })
673}
674
675mod helpers;
676mod enum_variants;
677
678#[proc_macro_derive(VisitVariants, attributes(visit))]
679pub fn derive_visit_variants(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
680    let ast: DeriveInput = syn::parse(input).unwrap();
681
682    let syn::Data::Enum(data) = &ast.data else {
683        return syn::Error::new_spanned(&ast.ident, "VisitVariants can only be used on enums")
684            .to_compile_error()
685            .into();
686    };
687
688    match enum_variants::derive_all_variant_traits(&ast, data) {
689        Ok(tokens) => tokens.into(),
690        Err(e) => e.to_compile_error().into(),
691    }
692}