typesize_derive/
lib.rs

1use proc_macro2::{Ident, Punct, Spacing, TokenStream};
2use quote::{quote, ToTokens};
3use syn::{parse_macro_input, DeriveInput, Field};
4
5mod r#enum;
6mod r#struct;
7
8use r#enum::gen_enum;
9use r#struct::gen_struct;
10
11#[derive(Clone)]
12enum FieldConfig {
13    Default,
14    Skip,
15    With(syn::Path),
16}
17
/// How a field expression is handed to the per-field size function by
/// `gen_call_with_arg`.
#[derive(Copy, Clone)]
pub(crate) enum PassMode {
    /// Pass the expression unchanged.
    AsIs,
    /// Prepend `&` to the expression.
    InsertRef,
    /// Copy the value into a local and pass a reference to that local; used for
    /// `#[repr(packed)]` types.
    Packed,
}
24
25fn gen_call_with_arg(
26    func_name: &TokenStream,
27    arg: &TokenStream,
28    pass_mode: PassMode,
29) -> TokenStream {
30    match pass_mode {
31        PassMode::AsIs => quote!(#func_name(#arg)),
32        PassMode::InsertRef => quote!(#func_name(&#arg)),
33        PassMode::Packed => {
34            quote!(({
35                let __typesize_internal_temp = #arg;
36                #func_name(&__typesize_internal_temp)
37            }))
38        }
39    }
40}
41
42fn join_tokens(
43    exprs: impl ExactSizeIterator<Item = impl ToTokens>,
44    sep: impl ToTokens,
45) -> TokenStream {
46    let expr_count = exprs.len();
47    let mut out_tokens = TokenStream::new();
48    for (i, expr) in exprs.enumerate() {
49        expr.to_tokens(&mut out_tokens);
50        if expr_count != i + 1 {
51            sep.to_tokens(&mut out_tokens);
52        }
53    }
54
55    out_tokens
56}
57
58fn try_join_tokens(
59    exprs: impl ExactSizeIterator<Item = syn::Result<impl ToTokens>>,
60    sep: impl ToTokens,
61) -> syn::Result<TokenStream> {
62    let expr_count = exprs.len();
63    let mut out_tokens = TokenStream::new();
64    for (i, expr) in exprs.enumerate() {
65        expr?.to_tokens(&mut out_tokens);
66        if expr_count != i + 1 {
67            sep.to_tokens(&mut out_tokens);
68        }
69    }
70
71    Ok(out_tokens)
72}
73
74fn gen_named_exprs<'a>(
75    named_fields: syn::punctuated::Iter<'a, Field>,
76    transform_named: impl Fn(&'a Ident) -> TokenStream + 'a,
77    common_body: impl Fn(TokenStream, TokenStream, FieldConfig) -> TokenStream + 'a,
78) -> Option<impl ExactSizeIterator<Item = syn::Result<TokenStream>> + 'a> {
79    if named_fields.len() == 0 {
80        return None;
81    }
82
83    Some(named_fields.map(move |field| {
84        let ident = field.ident.as_ref().unwrap();
85        let field_config = get_field_config(&field.attrs)?;
86        Ok(common_body(
87            transform_named(ident),
88            quote!(#ident),
89            field_config,
90        ))
91    }))
92}
93
94fn gen_unnamed_exprs<'a>(
95    unnamed_fields: syn::punctuated::Iter<'a, Field>,
96    transform_unnamed: impl Fn(usize) -> TokenStream + 'a,
97    common_body: impl Fn(TokenStream, TokenStream, FieldConfig) -> TokenStream + 'a,
98) -> Option<impl ExactSizeIterator<Item = syn::Result<TokenStream>> + 'a> {
99    if unnamed_fields.len() == 0 {
100        return None;
101    };
102
103    let enumerated_iter = unnamed_fields.enumerate();
104    Some(enumerated_iter.map(move |(i, field)| {
105        let field_config = get_field_config(&field.attrs)?;
106        Ok(common_body(transform_unnamed(i), quote!(#i), field_config))
107    }))
108}
109
110fn for_each_field<'a>(
111    fields: &'a syn::Fields,
112    join_with: Punct,
113    transform_named: impl Fn(&'a Ident) -> TokenStream + 'a,
114    transform_unnamed: impl Fn(usize) -> TokenStream + 'a,
115    common_body: impl Fn(TokenStream, TokenStream, FieldConfig) -> TokenStream + 'a,
116) -> Option<syn::Result<TokenStream>> {
117    match fields {
118        syn::Fields::Named(fields) => Some(try_join_tokens(
119            gen_named_exprs(fields.named.iter(), transform_named, common_body)?,
120            join_with,
121        )),
122        syn::Fields::Unnamed(fields) => Some(try_join_tokens(
123            gen_unnamed_exprs(fields.unnamed.iter(), transform_unnamed, common_body)?,
124            join_with,
125        )),
126        syn::Fields::Unit => None,
127    }
128}
129
130fn extra_details_visit_fields<'a>(
131    fields: &'a syn::Fields,
132    transform_named: impl Fn(&'a Ident) -> TokenStream + 'a,
133    transform_unnamed: impl Fn(usize) -> TokenStream + 'a,
134    pass_mode: PassMode,
135) -> syn::Result<TokenStream> {
136    for_each_field(
137        fields,
138        Punct::new('+', Spacing::Alone),
139        transform_named,
140        transform_unnamed,
141        move |ident, _name, config| match config {
142            FieldConfig::Skip => quote!(0),
143            FieldConfig::Default => {
144                gen_call_with_arg(&quote!(::typesize::TypeSize::extra_size), &ident, pass_mode)
145            }
146            FieldConfig::With(fn_path) => {
147                gen_call_with_arg(&fn_path.into_token_stream(), &ident, pass_mode)
148            }
149        },
150    )
151    .unwrap_or_else(|| Ok(quote!(0_usize)))
152}
153
154fn check_repr_packed(attrs: &[syn::Attribute]) -> bool {
155    fn is_valid_repr_for_packed(ident: &syn::Ident) -> bool {
156        // "packed may only be applied to the Rust and C representations."
157        // https://doc.rust-lang.org/reference/type-layout.html#the-alignment-modifiers
158        ident == "C" || ident == "Rust"
159    }
160
161    struct CheckIsPacked(bool);
162    impl syn::parse::Parse for CheckIsPacked {
163        fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
164            let first_token = input.parse::<syn::Ident>()?;
165            if !input.peek(syn::Token![,]) {
166                let is_packed = first_token == "packed";
167                return Ok(Self(is_packed));
168            }
169
170            input.parse::<syn::Token![,]>()?;
171
172            let second_token = input.parse::<syn::Ident>()?;
173            if is_valid_repr_for_packed(&first_token) && second_token == "packed" {
174                return Ok(Self(true));
175            }
176
177            if first_token == "packed" && is_valid_repr_for_packed(&second_token) {
178                return Ok(Self(true));
179            }
180
181            Ok(Self(false))
182        }
183    }
184
185    attrs.iter().any(|attr| {
186        let syn::Meta::List(meta) = &attr.meta else {
187            return false;
188        };
189
190        let Some(ident) = meta.path.get_ident() else {
191            return false;
192        };
193
194        if ident != "repr" {
195            return false;
196        }
197
198        syn::parse2::<CheckIsPacked>(meta.tokens.clone()).unwrap().0
199    })
200}
201
202fn get_field_config(attrs: &[syn::Attribute]) -> syn::Result<FieldConfig> {
203    // based on
204    // https://docs.rs/syn/latest/syn/macro.custom_keyword.html#example
205
206    mod kw {
207        syn::custom_keyword!(skip);
208        syn::custom_keyword!(with);
209    }
210
211    enum Input {
212        Skip {
213            _skip: kw::skip,
214        },
215        With {
216            _with: kw::with,
217            _eq: syn::Token![=],
218            path: syn::Path,
219        },
220    }
221
222    impl syn::parse::Parse for Input {
223        fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
224            let lookahead = input.lookahead1();
225            if lookahead.peek(kw::skip) {
226                Ok(Self::Skip {
227                    _skip: input.parse()?,
228                })
229            } else if lookahead.peek(kw::with) {
230                Ok(Self::With {
231                    _with: input.parse()?,
232                    _eq: input.parse()?,
233                    path: input.parse()?,
234                })
235            } else {
236                Err(lookahead.error())
237            }
238        }
239    }
240
241    for attr in attrs {
242        let syn::Meta::List(meta) = &attr.meta else {
243            continue;
244        };
245
246        let Some(path) = meta.path.get_ident() else {
247            continue;
248        };
249
250        if path != "typesize" {
251            continue;
252        }
253
254        let input = syn::parse::<Input>(meta.tokens.clone().into())?;
255        return Ok(match input {
256            Input::Skip { .. } => FieldConfig::Skip,
257            Input::With { path, .. } => FieldConfig::With(path),
258        });
259    }
260
261    Ok(FieldConfig::Default)
262}
263
264struct GenerationRet {
265    extra_size: TokenStream,
266    #[cfg(feature = "details")]
267    details: Option<TokenStream>,
268}
269
270/// Implements `TypeSize` automatically for a `struct` or `enum`.
271///
272/// Use `#[typesize(skip)]` on a field to assume it does not manage any external memory.
273///
274/// Use `#[typesize(with = path::to::fn)]` on a field to specify a custom `extra_size` for that field.
275/// The function accepts a reference to the type of the field.
276///
277/// This will avoid requiring `TypeSize` to be implemented for this field, however may lead to undercounted results if the assumption does not hold.
278///
279/// # Struct Mode
280///
281/// `TypeSize::extra_size` will be calculated by adding up the `extra_size` of all fields.
282///
283/// # Enum Mode
284///
285/// `TypeSize::extra_size` will be calculated by adding up the `extra_size` of all of the fields of the active enum variant.
286///
287/// # Union Mode
288///
289/// Unions are unsupported as there is no safe way to calculate the `extra_size`, implement `typesize::TypeSize` manually.
290#[proc_macro_derive(TypeSize, attributes(typesize))]
291pub fn typesize_derive(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
292    let DeriveInput {
293        attrs,
294        vis: _,
295        ident,
296        generics,
297        data,
298    } = parse_macro_input!(tokens as DeriveInput);
299
300    let is_packed = check_repr_packed(&attrs);
301
302    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
303    let bodies = match data {
304        syn::Data::Struct(data) => gen_struct(&data.fields, is_packed),
305        syn::Data::Enum(data) => gen_enum(data.variants.into_iter(), is_packed),
306        syn::Data::Union(data) => Err(syn::Error::new(
307            data.union_token.span,
308            "Unions are unsupported for typesize derive.",
309        )),
310    };
311
312    let bodies = match bodies {
313        Ok(bodies) => bodies,
314        Err(err) => {
315            return err.into_compile_error().into();
316        }
317    };
318
319    let extra_size = bodies.extra_size;
320    #[cfg_attr(not(feature = "details"), allow(unused_mut))]
321    let mut impl_body = quote!(
322        fn extra_size(&self) -> usize {
323            #extra_size
324        }
325    );
326
327    #[cfg(feature = "details")]
328    if let Some(details) = bodies.details {
329        impl_body = quote!(
330            #impl_body
331
332            fn get_size_details(&self) -> Vec<::typesize::Field> {
333                #details
334            }
335        );
336    }
337
338    let output = quote! {
339        #[automatically_derived]
340        impl #impl_generics ::typesize::TypeSize for #ident #ty_generics #where_clause {
341            #impl_body
342        }
343    };
344
345    output.into()
346}