type_weave_derive/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{quote, quote_spanned};
3use syn::{
4    parse_macro_input, parse_quote, spanned::Spanned, Data, DataStruct, DeriveInput, Error, Fields,
5    FieldsNamed, FieldsUnnamed, GenericParam, Generics, Ident, Index, Result,
6};
7
/// Unpack a `Result` inside a function that returns a token stream.
///
/// On `Ok` the macro evaluates to the contained value. On `Err` it returns
/// early from the enclosing function with the error rendered through
/// `into_compile_error()` and converted to the function's return type.
macro_rules! unwrap {
    ($fallible:expr) => {
        match $fallible {
            Err(failure) => return failure.into_compile_error().into(),
            Ok(inner) => inner,
        }
    };
}
16
17/// Implement `Weave` for a struct whose fields all implement the trait.
18#[proc_macro_derive(Weave)]
19pub fn weave(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
20    let input = parse_macro_input!(input as DeriveInput);
21
22    let name = input.ident;
23    let data = unwrap!(struct_data(input.data));
24
25    let generics = add_bounds(input.generics);
26    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
27
28    let other = Ident::new("other", Span::call_site());
29    let over = producer("over", &other, &data.fields);
30    let under = producer("under", &other, &data.fields);
31
32    quote! {
33        impl #impl_generics type_weave::Weave for #name #ty_generics #where_clause {
34            fn over(self, #other: Self) -> Self {
35                #over
36            }
37
38            fn under(self, #other: Self) -> Self {
39                #under
40            }
41        }
42    }
43    .into()
44}
45
46fn struct_data(data: Data) -> Result<DataStruct> {
47    let err = "Weave derive macro only supports structs";
48    match data {
49        Data::Struct(data) => Ok(data),
50        Data::Enum(data) => Err(Error::new(data.enum_token.span(), err)),
51        Data::Union(data) => Err(Error::new(data.union_token.span(), err)),
52    }
53}
54
55fn producer(method: &str, other: &Ident, fields: &Fields) -> TokenStream {
56    let method = Ident::new(method, Span::call_site());
57    match fields {
58        Fields::Unit => quote! { Self },
59        Fields::Named(FieldsNamed { named, .. }) => {
60            let fields: Vec<_> = named
61                .iter()
62                .map(|field| {
63                    let ident = field.ident.as_ref().unwrap();
64                    quote_spanned! {field.span()=>
65                        #ident: type_weave::Weave::#method(self.#ident, #other.#ident)
66                    }
67                })
68                .collect();
69            quote! { Self { #(#fields,)* } }
70        }
71        Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => {
72            let fields: Vec<_> = unnamed
73                .iter()
74                .enumerate()
75                .map(|(i, field)| {
76                    let index = Index {
77                        index: i as u32,
78                        span: field.span(),
79                    };
80                    quote_spanned! {field.span()=>
81                        type_weave::Weave::#method(self.#index, #other.#index)
82                    }
83                })
84                .collect();
85            quote! { Self(#(#fields),*) }
86        }
87    }
88}
89
90fn add_bounds(mut generics: Generics) -> Generics {
91    for param in &mut generics.params {
92        if let GenericParam::Type(ty) = param {
93            ty.bounds.push(parse_quote!(type_weave::Weave));
94        }
95    }
96    generics
97}