Skip to main content

typeshift_derive/
lib.rs

1//! Proc macros for `typeshift`.
2//!
3//! `#[typeshift]` is the primary entry point. It augments a struct/enum with
4//! derives and helper attributes required by `serde`, `validator`, and
5//! `schemars`.
6
7use proc_macro::TokenStream;
8use quote::{format_ident, quote};
9use syn::{Attribute, Fields, Item, ItemEnum, parse_macro_input};
10
11#[proc_macro_attribute]
12pub fn typeshift(_attr: TokenStream, item: TokenStream) -> TokenStream {
13    let mut item = parse_macro_input!(item as Item);
14
15    match &mut item {
16        Item::Struct(input) => {
17            apply_typeshift_attrs(&mut input.attrs, true);
18            quote!(#input).into()
19        }
20        Item::Enum(input) => {
21            apply_typeshift_attrs(&mut input.attrs, false);
22
23            let validate_impl = build_enum_validate_impl(input);
24
25            quote! {
26                #input
27                #validate_impl
28            }
29            .into()
30        }
31        _ => syn::Error::new_spanned(item, "#[typeshift] supports structs and enums only")
32            .to_compile_error()
33            .into(),
34    }
35}
36
37fn build_enum_validate_impl(input: &ItemEnum) -> proc_macro2::TokenStream {
38    if has_derived_trait(&input.attrs, "Validate") {
39        return quote! {};
40    }
41
42    let ident = &input.ident;
43    let generics = &input.generics;
44    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
45    let helper_generics_def = helper_def_generics(generics);
46    let helper_generics_use = helper_use_generics(generics);
47
48    let helper_defs = input.variants.iter().filter_map(|variant| {
49        let variant_ident = &variant.ident;
50        let helper_ident = format_ident!("__TypeShiftValidate{}{}", ident, variant_ident);
51        match &variant.fields {
52            Fields::Unit => None,
53            Fields::Named(fields) => {
54                let defs = fields.named.iter().map(|field| {
55                    let attrs = validate_attrs(&field.attrs);
56                    let name = match &field.ident {
57                        Some(name) => name,
58                        None => unreachable!("named field must have ident"),
59                    };
60                    let ty = &field.ty;
61                    quote! { #(#attrs)* #name: &'__typeshift_enum_validate #ty }
62                });
63
64                Some(quote! {
65                    #[allow(dead_code)]
66                    #[derive(::typeshift::validator::Validate)]
67                    #[validate(crate = "typeshift::validator")]
68                    struct #helper_ident #helper_generics_def #where_clause {
69                        #(#defs,)*
70                    }
71                })
72            }
73            Fields::Unnamed(fields) => {
74                let defs = fields.unnamed.iter().enumerate().map(|(idx, field)| {
75                    let attrs = validate_attrs(&field.attrs);
76                    let name = format_ident!("__field_{idx}");
77                    let ty = &field.ty;
78                    quote! { #(#attrs)* #name: &'__typeshift_enum_validate #ty }
79                });
80
81                Some(quote! {
82                    #[allow(dead_code)]
83                    #[derive(::typeshift::validator::Validate)]
84                    #[validate(crate = "typeshift::validator")]
85                    struct #helper_ident #helper_generics_def #where_clause {
86                        #(#defs,)*
87                    }
88                })
89            }
90        }
91    });
92
93    let arms = input.variants.iter().map(|variant| {
94        let variant_ident = &variant.ident;
95        let helper_ident = format_ident!("__TypeShiftValidate{}{}", ident, variant_ident);
96        match &variant.fields {
97            Fields::Unit => {
98                quote! {
99                    Self::#variant_ident => ::core::result::Result::Ok(())
100                }
101            }
102            Fields::Named(fields) => {
103                let names: Vec<_> = fields
104                    .named
105                    .iter()
106                    .filter_map(|field| field.ident.as_ref())
107                    .collect();
108                quote! {
109                    Self::#variant_ident { #(#names,)* } => {
110                        let helper = #helper_ident #helper_generics_use { #(#names,)* };
111                        ::typeshift::validator::Validate::validate(&helper)
112                    }
113                }
114            }
115            Fields::Unnamed(fields) => {
116                let bindings: Vec<_> = fields
117                    .unnamed
118                    .iter()
119                    .enumerate()
120                    .map(|(idx, _)| format_ident!("__field_{idx}"))
121                    .collect();
122                let init_fields = bindings.iter().map(|name| quote! { #name: #name });
123                quote! {
124                    Self::#variant_ident( #(#bindings,)* ) => {
125                        let helper = #helper_ident #helper_generics_use { #(#init_fields,)* };
126                        ::typeshift::validator::Validate::validate(&helper)
127                    }
128                }
129            }
130        }
131    });
132
133    quote! {
134        #(#helper_defs)*
135
136        impl #impl_generics ::typeshift::validator::Validate for #ident #ty_generics #where_clause {
137            fn validate(&self) -> ::core::result::Result<(), ::typeshift::validator::ValidationErrors> {
138                match self {
139                    #(#arms,)*
140                }
141            }
142        }
143    }
144}
145
146fn helper_def_generics(generics: &syn::Generics) -> proc_macro2::TokenStream {
147    let params = &generics.params;
148    if params.is_empty() {
149        quote! { <'__typeshift_enum_validate> }
150    } else {
151        quote! { <'__typeshift_enum_validate, #params> }
152    }
153}
154
155fn helper_use_generics(generics: &syn::Generics) -> proc_macro2::TokenStream {
156    let args: Vec<proc_macro2::TokenStream> = generics
157        .params
158        .iter()
159        .map(|param| match param {
160            syn::GenericParam::Type(ty) => {
161                let ident = &ty.ident;
162                quote! { #ident }
163            }
164            syn::GenericParam::Lifetime(lt) => {
165                let lifetime = &lt.lifetime;
166                quote! { #lifetime }
167            }
168            syn::GenericParam::Const(konst) => {
169                let ident = &konst.ident;
170                quote! { #ident }
171            }
172        })
173        .collect();
174
175    if args.is_empty() {
176        quote! { ::<'_> }
177    } else {
178        quote! { ::<'_, #(#args,)*> }
179    }
180}
181
182fn validate_attrs(attrs: &[Attribute]) -> Vec<Attribute> {
183    attrs
184        .iter()
185        .filter(|attr| attr.path().is_ident("validate"))
186        .cloned()
187        .collect()
188}
189
190#[proc_macro_derive(TypeShift, attributes(validate, serde, schemars))]
191/// Legacy compatibility derive.
192///
193/// This derive intentionally generates no code. Use `#[typeshift]` as the
194/// primary macro entry point.
195pub fn derive_typeshift(_input: TokenStream) -> TokenStream {
196    TokenStream::new()
197}
198
199fn apply_typeshift_attrs(attrs: &mut Vec<Attribute>, include_validate: bool) {
200    let mut required = vec!["Serialize", "Deserialize", "JsonSchema"];
201    if include_validate {
202        required.push("Validate");
203    }
204    add_missing_derives(attrs, &required);
205    ensure_attr(attrs, "serde", "crate = \"typeshift::serde\"");
206    ensure_attr(attrs, "schemars", "crate = \"typeshift::schemars\"");
207    if include_validate {
208        ensure_attr(attrs, "validate", "crate = \"typeshift::validator\"");
209    }
210}
211
212fn has_derived_trait(attrs: &[Attribute], trait_name: &str) -> bool {
213    attrs
214        .iter()
215        .filter(|attr| attr.path().is_ident("derive"))
216        .filter_map(|attr| {
217            attr.parse_args_with(
218                syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
219            )
220            .ok()
221        })
222        .flat_map(|paths| paths.into_iter())
223        .any(|path| {
224            path.segments
225                .last()
226                .map(|seg| seg.ident == trait_name)
227                .unwrap_or(false)
228        })
229}
230
231fn add_missing_derives(attrs: &mut Vec<Attribute>, required: &[&str]) {
232    let mut missing = Vec::new();
233    for name in required {
234        if has_derived_trait(attrs, name) {
235            continue;
236        }
237        let path: syn::Path = match *name {
238            "Serialize" => syn::parse_quote!(::typeshift::serde::Serialize),
239            "Deserialize" => syn::parse_quote!(::typeshift::serde::Deserialize),
240            "Validate" => syn::parse_quote!(::typeshift::validator::Validate),
241            "JsonSchema" => syn::parse_quote!(::typeshift::schemars::JsonSchema),
242            _ => continue,
243        };
244        missing.push(path);
245    }
246
247    if !missing.is_empty() {
248        let insert_at = attrs
249            .iter()
250            .rposition(|attr| attr.path().is_ident("derive"))
251            .map(|index| index + 1)
252            .unwrap_or(0);
253        attrs.insert(insert_at, syn::parse_quote!(#[derive(#(#missing),*)]));
254    }
255}
256
257fn ensure_attr(attrs: &mut Vec<Attribute>, name: &str, args: &str) {
258    let path = syn::Ident::new(name, proc_macro2::Span::call_site());
259    let args: proc_macro2::TokenStream = match args.parse() {
260        Ok(args) => args,
261        Err(_) => return,
262    };
263
264    let has_crate_arg = attrs
265        .iter()
266        .any(|attr| attr.path().is_ident(name) && attr_has_crate_arg(attr));
267
268    if !has_crate_arg {
269        attrs.push(syn::parse_quote!(#[#path(#args)]));
270    }
271}
272
273fn attr_has_crate_arg(attr: &Attribute) -> bool {
274    attr.parse_args_with(syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated)
275        .map(|metas| {
276            metas.into_iter().any(|meta| {
277                if let syn::Meta::NameValue(name_value) = meta {
278                    return name_value.path.is_ident("crate");
279                }
280                false
281            })
282        })
283        .unwrap_or(false)
284}