shrink_to_fit_macro/
lib.rs

1use quote::{quote, ToTokens};
2use syn::{spanned::Spanned, Attribute, Expr, Ident, Lit, Meta};
3
4#[proc_macro_derive(ShrinkToFit, attributes(shrink_to_fit))]
5pub fn derive_shrink_to_fit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
6    let input: syn::DeriveInput = syn::parse_macro_input!(input as syn::DeriveInput);
7    let type_attr: TypeAttr = TypeAttr::parse(&input.attrs);
8
9    let crate_name = type_attr
10        .crate_name
11        .as_ref()
12        .map(|q| q.to_token_stream())
13        .unwrap_or_else(|| quote!(shrink_to_fit));
14
15    let name = &input.ident;
16    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
17
18    let body_impl = match &input.data {
19        syn::Data::Struct(s) => {
20            let (field_bindings, body_code) = expand_fields(&type_attr, &s.fields);
21
22            quote!(
23                match self {
24                    Self { #field_bindings } => {
25                        #body_code
26                    }
27                }
28            )
29        }
30
31        syn::Data::Enum(e) => {
32            let mut arms = proc_macro2::TokenStream::new();
33
34            for v in e.variants.iter() {
35                let variant_name = &v.ident;
36
37                let (field_bindings, body_code) = expand_fields(&type_attr, &v.fields);
38
39                arms.extend(quote!(
40                    Self::#variant_name { #field_bindings } => {
41                        #body_code
42                    },
43                ));
44            }
45
46            quote!(
47                match self {
48                    #arms
49                }
50            )
51        }
52
53        syn::Data::Union(_) => {
54            panic!("union is not supported");
55        }
56    };
57
58    quote! {
59        impl<#impl_generics> #crate_name::ShrinkToFit for #name<#ty_generics> #where_clause {
60            fn shrink_to_fit(&mut self) {
61                #body_impl
62            }
63        }
64    }
65    .into()
66}
67
68#[derive(Default)]
69struct TypeAttr {
70    crate_name: Option<syn::Path>,
71}
72impl TypeAttr {
73    fn parse(attrs: &[Attribute]) -> TypeAttr {
74        let mut data_attr = TypeAttr::default();
75
76        for attr in attrs {
77            if attr.path().is_ident("shrink_to_fit") {
78                if let Meta::List(meta) = &attr.meta {
79                    let tokens = meta.tokens.clone();
80                    let kv = syn::parse2::<syn::MetaNameValue>(tokens).unwrap();
81
82                    if kv.path.is_ident("crate") {
83                        if let Expr::Lit(syn::ExprLit {
84                            lit: Lit::Str(s), ..
85                        }) = &kv.value
86                        {
87                            let path = syn::parse_str::<syn::Path>(&s.value()).unwrap();
88                            data_attr.crate_name = Some(path);
89                        }
90                    }
91                }
92            }
93        }
94
95        data_attr
96    }
97}
98
99/// Returns `(field_bindings, body_code)`
100fn expand_fields(
101    type_attr: &TypeAttr,
102    fields: &syn::Fields,
103) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
104    let crate_name = type_attr
105        .crate_name
106        .as_ref()
107        .map(|q| q.to_token_stream())
108        .unwrap_or_else(|| quote!(shrink_to_fit));
109
110    let mut field_bindings = proc_macro2::TokenStream::new();
111    let mut body_impl = proc_macro2::TokenStream::new();
112
113    match fields {
114        syn::Fields::Named(fields) => {
115            for field in fields.named.iter() {
116                let field_name = field.ident.as_ref().unwrap();
117
118                field_bindings.extend(quote!(
119                    ref mut #field_name,
120                ));
121
122                body_impl.extend(quote!(
123                    #crate_name::helpers::ShrinkToFitDerefSpecialization::new(#field_name).shrink_to_fit();
124                ));
125            }
126        }
127
128        syn::Fields::Unnamed(fields) => {
129            for (i, field) in fields.unnamed.iter().enumerate() {
130                let field_name = Ident::new(&format!("_{}", i), field.span());
131
132                body_impl.extend(quote!(
133                    #crate_name::helpers::ShrinkToFitDerefSpecialization::new(#field_name).shrink_to_fit();
134                ));
135
136                let index = syn::Index::from(i);
137                field_bindings.extend(quote!(
138                    #index: ref mut #field_name,
139                ));
140            }
141        }
142
143        syn::Fields::Unit => {}
144    }
145
146    (field_bindings, body_impl)
147}