Skip to main content

typeshift_derive/
lib.rs

//! Proc macros for `typeshift`.
//!
//! `#[typeshift]` is the primary entry point. It augments a struct/enum with
//! derives and helper attributes required by `serde`, `validator`, and
//! `schemars`.
6
7use proc_macro::TokenStream;
8use quote::quote;
9use syn::{Attribute, Item, parse_macro_input};
10
11#[proc_macro_attribute]
12pub fn typeshift(_attr: TokenStream, item: TokenStream) -> TokenStream {
13    let mut item = parse_macro_input!(item as Item);
14
15    match &mut item {
16        Item::Struct(input) => {
17            apply_typeshift_attrs(&mut input.attrs, true);
18            quote!(#input).into()
19        }
20        Item::Enum(input) => {
21            apply_typeshift_attrs(&mut input.attrs, false);
22
23            let ident = &input.ident;
24            let generics = &input.generics;
25            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
26            let validate_impl = if has_derived_trait(&input.attrs, "Validate") {
27                quote! {}
28            } else {
29                quote! {
30                    impl #impl_generics ::typeshift::validator::Validate for #ident #ty_generics #where_clause {
31                        fn validate(&self) -> ::core::result::Result<(), ::typeshift::validator::ValidationErrors> {
32                            ::core::result::Result::Ok(())
33                        }
34                    }
35                }
36            };
37
38            quote! {
39                #input
40                #validate_impl
41            }
42            .into()
43        }
44        _ => syn::Error::new_spanned(item, "#[typeshift] supports structs and enums only")
45            .to_compile_error()
46            .into(),
47    }
48}
49
50#[proc_macro_derive(TypeShift, attributes(validate, serde, schemars))]
51/// Legacy compatibility derive.
52///
53/// This derive intentionally generates no code. Use `#[typeshift]` as the
54/// primary macro entry point.
55pub fn derive_typeshift(_input: TokenStream) -> TokenStream {
56    TokenStream::new()
57}
58
59fn apply_typeshift_attrs(attrs: &mut Vec<Attribute>, include_validate: bool) {
60    let mut required = vec!["Serialize", "Deserialize", "JsonSchema"];
61    if include_validate {
62        required.push("Validate");
63    }
64    add_missing_derives(attrs, &required);
65    ensure_attr(attrs, "serde", "crate = \"typeshift::serde\"");
66    ensure_attr(attrs, "schemars", "crate = \"typeshift::schemars\"");
67    if include_validate {
68        ensure_attr(attrs, "validate", "crate = \"typeshift::validator\"");
69    }
70}
71
72fn has_derived_trait(attrs: &[Attribute], trait_name: &str) -> bool {
73    attrs
74        .iter()
75        .filter(|attr| attr.path().is_ident("derive"))
76        .filter_map(|attr| {
77            attr.parse_args_with(
78                syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
79            )
80            .ok()
81        })
82        .flat_map(|paths| paths.into_iter())
83        .any(|path| {
84            path.segments
85                .last()
86                .map(|seg| seg.ident == trait_name)
87                .unwrap_or(false)
88        })
89}
90
91fn add_missing_derives(attrs: &mut Vec<Attribute>, required: &[&str]) {
92    let mut missing = Vec::new();
93    for name in required {
94        if has_derived_trait(attrs, name) {
95            continue;
96        }
97        let path: syn::Path = match *name {
98            "Serialize" => syn::parse_quote!(::typeshift::serde::Serialize),
99            "Deserialize" => syn::parse_quote!(::typeshift::serde::Deserialize),
100            "Validate" => syn::parse_quote!(::typeshift::validator::Validate),
101            "JsonSchema" => syn::parse_quote!(::typeshift::schemars::JsonSchema),
102            _ => continue,
103        };
104        missing.push(path);
105    }
106
107    if !missing.is_empty() {
108        let insert_at = attrs
109            .iter()
110            .rposition(|attr| attr.path().is_ident("derive"))
111            .map(|index| index + 1)
112            .unwrap_or(0);
113        attrs.insert(insert_at, syn::parse_quote!(#[derive(#(#missing),*)]));
114    }
115}
116
117fn ensure_attr(attrs: &mut Vec<Attribute>, name: &str, args: &str) {
118    let path = syn::Ident::new(name, proc_macro2::Span::call_site());
119    let args: proc_macro2::TokenStream = match args.parse() {
120        Ok(args) => args,
121        Err(_) => return,
122    };
123
124    let has_crate_arg = attrs
125        .iter()
126        .any(|attr| attr.path().is_ident(name) && attr_has_crate_arg(attr));
127
128    if !has_crate_arg {
129        attrs.push(syn::parse_quote!(#[#path(#args)]));
130    }
131}
132
133fn attr_has_crate_arg(attr: &Attribute) -> bool {
134    attr.parse_args_with(syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated)
135        .map(|metas| {
136            metas.into_iter().any(|meta| {
137                if let syn::Meta::NameValue(name_value) = meta {
138                    return name_value.path.is_ident("crate");
139                }
140                false
141            })
142        })
143        .unwrap_or(false)
144}