Skip to main content

type_signature_derive/
lib.rs

//! Derive macros for the `type-signature` crate.
2
3use std::collections::HashSet;
4
5use proc_macro::TokenStream as TokenStream1;
6
7use proc_macro2::{Span, TokenStream};
8use quote::quote;
9use syn::{DeriveInput, Path, parse_macro_input};
10
11/// A struct collecting all info needed for [`derive_type_signature`].
12struct TypeSignatureImpl {
13    /// The identifier for the target type.
14    ident: syn::Ident,
15    /// Any generics on the target type.
16    generics: syn::Generics,
17    /// Extra `FieldTy: TypeSignature` bounds derived from the types of the (non-skipped)
18    /// fields.
19    ///
20    /// Used in addition to the unconditional `T: TypeSignature` bound on every generic type
21    /// parameter, to cover user-defined generic types whose `TypeSignature` impl carries extra
22    /// trait bounds (e.g. `MyWrapper<T> where T: SomeTrait + TypeSignature`). Only populated when
23    /// the type has at least one generic type parameter, to avoid adding a `SomeConcreteTy:
24    /// TypeSignature` bound that could effectively disable the impl.
25    generic_constraints: Vec<syn::Type>,
26    /// The list of variants for this type.
27    ///
28    /// For a struct, there is only one variant, but an enum may have multiple.
29    variants: Vec<TokenStream>,
30    /// If `Some`, override the name emitted into the signature (from `#[type_signature(rename = "...")]`).
31    rename: Option<String>,
32    /// The path to use for accessing the `type_signature` crate.
33    crate_path: Path,
34}
35impl TryFrom<DeriveInput> for TypeSignatureImpl {
36    type Error = syn::Error;
37
38    fn try_from(ast: DeriveInput) -> syn::Result<Self> {
39        let type_attrs = TypeAttrs::parse(&ast.attrs)?;
40        let crate_path = type_attrs.crate_path.unwrap_or_else(|| Path {
41            leading_colon: Some(syn::token::PathSep(Span::call_site())),
42            segments: {
43                let mut segments = syn::punctuated::Punctuated::new();
44                segments.push(syn::Ident::new("type_signature", Span::call_site()).into());
45                segments
46            },
47        });
48        for param in &ast.generics.params {
49            if let syn::GenericParam::Const(const_param) = param {
50                let is_ident = matches!(
51                    &const_param.ty,
52                    syn::Type::Path(syn::TypePath { qself: None, path })
53                        if path.get_ident().is_some()
54                );
55                if !is_ident {
56                    return Err(syn::Error::new_spanned(
57                        &const_param.ty,
58                        "TypeSignature derive only supports const generic parameters whose type is a simple identifier (e.g. `usize`, `bool`)",
59                    ));
60                }
61            }
62        }
63        let any_generic_tys = ast
64            .generics
65            .params
66            .iter()
67            .any(|param| matches!(param, syn::GenericParam::Type(_)));
68        let (variants, generic_constraints) = match ast.data {
69            syn::Data::Struct(st) => {
70                let (field_impls, field_tys) = fields_info(&st.fields, &crate_path)?;
71                let variants = vec![quote!(("", &[ #( #field_impls ),* ]))];
72                (variants, field_tys)
73            }
74            syn::Data::Enum(en) => {
75                let rows = en
76                    .variants
77                    .iter()
78                    .map(|variant| -> syn::Result<_> {
79                        let variant_attrs = TypeAttrs::parse(&variant.attrs)?;
80                        let variant_name = variant_attrs
81                            .rename
82                            .unwrap_or_else(|| variant.ident.to_string());
83                        let (field_impls, field_tys) = fields_info(&variant.fields, &crate_path)?;
84                        let variant_impl = quote!((#variant_name, &[ #( #field_impls ),* ]));
85                        Ok((variant_impl, field_tys))
86                    })
87                    .collect::<syn::Result<Vec<_>>>()?;
88                let (variants, per_variant_field_tys): (Vec<_>, Vec<_>) = rows.into_iter().unzip();
89                let field_tys = deduplicate(per_variant_field_tys.into_iter().flatten()).collect();
90                (variants, field_tys)
91            }
92            syn::Data::Union(un) => un
93                .fields
94                .named
95                .iter()
96                .filter_map(|field| {
97                    let attrs = match FieldAttrs::parse(&field.attrs) {
98                        Ok(a) => a,
99                        Err(e) => return Some(Err(e)),
100                    };
101                    if attrs.skip {
102                        return None;
103                    }
104                    let name = attrs.rename.unwrap_or_else(|| {
105                        field
106                            .ident
107                            .as_ref()
108                            .expect("union fields are always named")
109                            .to_string()
110                    });
111                    let ty = &field.ty;
112                    let variant = quote!(
113                        (#name, &[("", &<#ty as #crate_path::TypeSignature>::SIGNATURE)])
114                    );
115                    Some(Ok((variant, field.ty.clone())))
116                })
117                .collect::<syn::Result<Vec<_>>>()?
118                .into_iter()
119                .unzip(),
120        };
121        // Only supply generic constraints if there's a generic type.
122        let generic_constraints = if any_generic_tys {
123            generic_constraints
124        } else {
125            Vec::new()
126        };
127        Ok(Self {
128            ident: ast.ident,
129            generics: ast.generics,
130            generic_constraints,
131            variants,
132            rename: type_attrs.rename,
133            crate_path,
134        })
135    }
136}
137impl quote::ToTokens for TypeSignatureImpl {
138    fn to_tokens(&self, tokens: &mut TokenStream) {
139        tokens.extend(self.to_token_stream());
140    }
141
142    fn to_token_stream(&self) -> TokenStream {
143        let (impl_generics, ty_generics, _) = self.generics.split_for_impl();
144        // Extract the raw predicates (without the leading `where` keyword) so we can merge
145        // them with our own `FieldTy: TypeSignature` bounds under a single `where` clause.
146        let user_where_predicates: Vec<&syn::WherePredicate> = self
147            .generics
148            .where_clause
149            .as_ref()
150            .map(|wc| wc.predicates.iter().collect())
151            .unwrap_or_default();
152        let ident = &self.ident;
153        let ty_name = self
154            .rename
155            .clone()
156            .unwrap_or_else(|| self.ident.to_string());
157        let generic_constraints = &self.generic_constraints;
158        let variants = &self.variants;
159        let crate_path = &self.crate_path;
160        // Every generic type parameter is unconditionally bounded by `TypeSignature`.
161        // This covers cases where the parameter appears only in `ty_generics` (e.g. empty
162        // enums, or structs where every generic-typed field is `#[type_signature(skip)]`).
163        let generic_ty_bounds = self.generics.params.iter().filter_map(|param| {
164            if let syn::GenericParam::Type(ty) = param {
165                let ident = &ty.ident;
166                Some(quote!(#ident: #crate_path::TypeSignature))
167            } else {
168                None
169            }
170        });
171        let generic_ty_signatures = self.generics.params.iter().filter_map(|param| {
172            if let syn::GenericParam::Type(ty) = param {
173                let ident = &ty.ident;
174                Some(quote!(&<#ident as #crate_path::TypeSignature>::SIGNATURE))
175            } else {
176                None
177            }
178        });
179        let const_generic_signatures = self.generics.params.iter().filter_map(|param| {
180            if let syn::GenericParam::Const(const_param) = param {
181                let syn::Type::Path(syn::TypePath { qself: None, path }) = &const_param.ty else {
182                    unreachable!("validated in TryFrom::try_from")
183                };
184                let param_ty = path
185                    .get_ident()
186                    .expect("validated in TryFrom::try_from")
187                    .to_string();
188                let hash_fn_name =
189                    syn::Ident::new(&format!("hash_const_{param_ty}"), Span::call_site());
190                let param_val = &const_param.ident;
191                let param_name = const_param.ident.to_string();
192                Some(quote! { const {
193                    let mut acc = #crate_path::__macro_export::hash_str(#param_name);
194                    #crate_path::__macro_export::mix_values(
195                        &mut acc,
196                        #crate_path::__macro_export::#hash_fn_name(#param_val)
197                    );
198                    acc
199                }})
200            } else {
201                None
202            }
203        });
204        quote! {
205            impl #impl_generics #crate_path::TypeSignature for #ident #ty_generics
206                where
207                    #( #user_where_predicates, )*
208                    #( #generic_ty_bounds, )*
209                    #( #generic_constraints: #crate_path::TypeSignature ),*
210            {
211                #![allow(single_use_lifetimes, reason = "Macro-generated code")]
212                const SIGNATURE: #crate_path::TypeSignatureHasher = #crate_path::TypeSignatureHasher {
213                    ty_name: #ty_name,
214                    ty_generics: &[ #( #generic_ty_signatures ),* ],
215                    const_generic_hashes: &[ #( #const_generic_signatures ),* ],
216                    variants: &[ #( #variants ),* ],
217                };
218            }
219        }
220    }
221}
222
223/// Derive macro for `TypeSignature`.
224#[proc_macro_derive(TypeSignature, attributes(type_signature))]
225pub fn derive_type_signature(input: TokenStream1) -> TokenStream1 {
226    let ast = parse_macro_input!(input as DeriveInput);
227    match TypeSignatureImpl::try_from(ast) {
228        Ok(imp) => quote!(#imp),
229        Err(e) => e.into_compile_error(),
230    }
231    .into()
232}
233
/// Yields each element of `elems` the first time it is seen, dropping later repeats.
///
/// The returned iterator is lazy and keeps the order in which distinct elements first appear.
/// Every element is cloned once to record it in the internal set.
fn deduplicate<T: core::hash::Hash + Eq + Clone>(
    elems: impl IntoIterator<Item = T>,
) -> impl Iterator<Item = T> {
    let mut already_yielded: HashSet<T> = HashSet::new();
    elems
        .into_iter()
        .filter_map(move |elem| already_yielded.insert(elem.clone()).then_some(elem))
}
241
242/// Build `(field_impl_tokens, field_type)` pairs for every field, covering unit/named/tuple shapes.
243///
244/// Fields marked `#[type_signature(skip)]` are omitted from both vectors.
245fn fields_info(
246    fields: &syn::Fields,
247    crate_path: &Path,
248) -> syn::Result<(Vec<TokenStream>, Vec<syn::Type>)> {
249    let rows = fields
250        .iter()
251        .enumerate()
252        .filter_map(|(idx, field)| {
253            let attrs = match FieldAttrs::parse(&field.attrs) {
254                Ok(a) => a,
255                Err(e) => return Some(Err(e)),
256            };
257            if attrs.skip {
258                return None;
259            }
260            let name = attrs.rename.unwrap_or_else(|| {
261                field
262                    .ident
263                    .as_ref()
264                    .map_or_else(|| idx.to_string(), syn::Ident::to_string)
265            });
266            let ty = &field.ty;
267            let impl_tokens = quote!((#name, &<#ty as #crate_path::TypeSignature>::SIGNATURE));
268            Some(Ok((impl_tokens, field.ty.clone())))
269        })
270        .collect::<syn::Result<Vec<_>>>()?;
271    Ok(rows.into_iter().unzip())
272}
273
274/// Parsed `#[type_signature(...)]` attributes at the type level.
275#[derive(Default)]
276struct TypeAttrs {
277    /// Override the name used in the signature. Lets the signature survive a type rename.
278    rename: Option<String>,
279    /// The path to the `type_signature` crate if it needs to be overridden.
280    crate_path: Option<Path>,
281}
282
283impl TypeAttrs {
284    fn parse(attrs: &[syn::Attribute]) -> syn::Result<Self> {
285        let mut out = Self::default();
286        for attr in attrs {
287            if !attr.path().is_ident("type_signature") {
288                continue;
289            }
290            attr.parse_nested_meta(|meta| {
291                if meta.path.is_ident("rename") {
292                    let lit: syn::LitStr = meta.value()?.parse()?;
293                    out.rename = Some(lit.value());
294                    Ok(())
295                } else if meta.path.is_ident("crate") {
296                    let crate_path: Path = meta.value()?.parse()?;
297                    out.crate_path = Some(crate_path);
298                    Ok(())
299                } else {
300                    Err(meta.error("unrecognized type_signature attribute {attr:?}"))
301                }
302            })?;
303        }
304        Ok(out)
305    }
306}
307
308/// Parsed `#[type_signature(...)]` attributes at the field level.
309#[derive(Default)]
310struct FieldAttrs {
311    /// Omit this field from the signature entirely.
312    skip: bool,
313    /// Override the name used for this field in the signature.
314    rename: Option<String>,
315}
316
317impl FieldAttrs {
318    fn parse(attrs: &[syn::Attribute]) -> syn::Result<Self> {
319        let mut out = Self::default();
320        for attr in attrs {
321            if !attr.path().is_ident("type_signature") {
322                continue;
323            }
324            attr.parse_nested_meta(|meta| {
325                if meta.path.is_ident("skip") {
326                    out.skip = true;
327                    Ok(())
328                } else if meta.path.is_ident("rename") {
329                    let lit: syn::LitStr = meta.value()?.parse()?;
330                    out.rename = Some(lit.value());
331                    Ok(())
332                } else {
333                    Err(meta.error(
334                        "unrecognized type_signature attribute; expected `skip` or `rename = \"...\"`",
335                    ))
336                }
337            })?;
338        }
339        Ok(out)
340    }
341}