1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use either::Either;
use proc_macro2::{Ident, Punct, Spacing, Span, TokenStream, TokenTree};
use quote::{quote, ToTokens};
use syn::{parse_macro_input, punctuated::Punctuated, DeriveInput, Field, Token, Variant};

fn join_tokens(
    exprs: impl ExactSizeIterator<Item = impl ToTokens>,
    sep: impl ToTokens,
) -> TokenStream {
    let expr_count = exprs.len();
    let mut out_tokens = TokenStream::new();
    for (i, expr) in exprs.enumerate() {
        expr.to_tokens(&mut out_tokens);
        if expr_count != i + 1 {
            sep.to_tokens(&mut out_tokens);
        }
    }

    out_tokens
}

fn gen_unnamed_ident(i: usize) -> Ident {
    Ident::new(&format!("var_field_{i}"), Span::call_site())
}

fn gen_named_exprs<'a>(
    named_fields: impl ExactSizeIterator<Item = &'a Field> + 'a,
    include_self: bool,
) -> impl ExactSizeIterator<Item = TokenStream> + 'a {
    if named_fields.len() == 0 {
        return Either::Left(std::iter::once(quote!(0_usize)));
    }

    Either::Right(named_fields.map(|field| &field.ident).map(move |ident| {
        if include_self {
            quote!(::typesize::TypeSize::extra_size(&self.#ident))
        } else {
            quote!(::typesize::TypeSize::extra_size(#ident))
        }
    }))
}

fn gen_unnamed_exprs(
    field_count: usize,
    include_self: bool,
) -> impl ExactSizeIterator<Item = TokenStream> {
    if field_count == 0 {
        return Either::Left(std::iter::once(quote!(0_usize)));
    };

    Either::Right((0..field_count).map(move |i| {
        if include_self {
            let ident = syn::Index::from(i);
            quote!(::typesize::TypeSize::extra_size(&self.#ident))
        } else {
            let ident = gen_unnamed_ident(i);
            quote!(::typesize::TypeSize::extra_size(#ident))
        }
    }))
}

fn gen_sum_exprs(exprs: impl ExactSizeIterator<Item = impl ToTokens>) -> TokenStream {
    join_tokens(exprs, Punct::new('+', Spacing::Alone))
}

fn gen_comma_exprs(exprs: impl ExactSizeIterator<Item = impl ToTokens>) -> TokenStream {
    join_tokens(exprs, Punct::new(',', Spacing::Alone))
}

fn gen_struct(data: &syn::Fields, include_self: bool) -> TokenStream {
    match data {
        syn::Fields::Named(fields) => {
            gen_sum_exprs(gen_named_exprs(fields.named.iter(), include_self))
        }
        syn::Fields::Unnamed(fields) => {
            gen_sum_exprs(gen_unnamed_exprs(fields.unnamed.len(), include_self))
        }
        syn::Fields::Unit => quote!(0_usize),
    }
}

fn get_named_idents(fields: &Punctuated<Field, Token![,]>) -> TokenStream {
    let idents = fields.iter().map(|field| field.ident.as_ref().unwrap());
    join_tokens(idents, TokenTree::Punct(Punct::new(',', Spacing::Alone)))
}

fn gen_unnamed_idents(field_count: usize) -> TokenStream {
    let idents = (0..field_count).map(gen_unnamed_ident);
    gen_comma_exprs(idents)
}

fn gen_enum(variants: impl Iterator<Item = Variant>) -> TokenStream {
    let match_arms: TokenStream = variants
        .map(|variant| {
            let variant_name = &variant.ident;
            let variant_body = gen_struct(&variant.fields, false);
            let variant_pattern = match &variant.fields {
                syn::Fields::Named(fields) => {
                    let field_names = get_named_idents(&fields.named);
                    quote!({#field_names})
                }
                syn::Fields::Unnamed(fields) => {
                    let field_names = gen_unnamed_idents(fields.unnamed.len());
                    quote!((#field_names))
                }
                syn::Fields::Unit => TokenStream::new(),
            };

            quote!(Self::#variant_name #variant_pattern => #variant_body,)
        })
        .collect();

    quote!(
        match self {
            #match_arms
        }
    )
}

#[proc_macro_derive(TypeSize)]
pub fn typesize_derive(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let DeriveInput {
        attrs: _,
        vis: _,
        ident,
        generics,
        data,
    } = parse_macro_input!(tokens as DeriveInput);

    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    let body = match data {
        syn::Data::Struct(data) => gen_struct(&data.fields, true),
        syn::Data::Enum(data) => gen_enum(data.variants.into_iter()),
        syn::Data::Union(_) => panic!("Unions are unsupported for typesize derive!"),
    };

    let output = quote! {
        #[automatically_derived]
        impl #impl_generics ::typesize::TypeSize for #ident #ty_generics #where_clause {
            fn extra_size(&self) -> usize {
                #body
            }
        }
    };

    output.into()
}