1use proc_macro::TokenStream;
8use quote::quote;
9use syn::{Attribute, Item, parse_macro_input};
10
11#[proc_macro_attribute]
12pub fn typeshift(_attr: TokenStream, item: TokenStream) -> TokenStream {
13 let mut item = parse_macro_input!(item as Item);
14
15 match &mut item {
16 Item::Struct(input) => {
17 apply_typeshift_attrs(&mut input.attrs, true);
18 quote!(#input).into()
19 }
20 Item::Enum(input) => {
21 apply_typeshift_attrs(&mut input.attrs, false);
22
23 let ident = &input.ident;
24 let generics = &input.generics;
25 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
26 let validate_impl = if has_derived_trait(&input.attrs, "Validate") {
27 quote! {}
28 } else {
29 quote! {
30 impl #impl_generics ::typeshift::validator::Validate for #ident #ty_generics #where_clause {
31 fn validate(&self) -> ::core::result::Result<(), ::typeshift::validator::ValidationErrors> {
32 ::core::result::Result::Ok(())
33 }
34 }
35 }
36 };
37
38 quote! {
39 #input
40 #validate_impl
41 }
42 .into()
43 }
44 _ => syn::Error::new_spanned(item, "#[typeshift] supports structs and enums only")
45 .to_compile_error()
46 .into(),
47 }
48}
49
50#[proc_macro_derive(TypeShift, attributes(validate, serde, schemars))]
51pub fn derive_typeshift(_input: TokenStream) -> TokenStream {
56 TokenStream::new()
57}
58
59fn apply_typeshift_attrs(attrs: &mut Vec<Attribute>, include_validate: bool) {
60 let mut required = vec!["Serialize", "Deserialize", "JsonSchema"];
61 if include_validate {
62 required.push("Validate");
63 }
64 add_missing_derives(attrs, &required);
65 ensure_attr(attrs, "serde", "crate = \"typeshift::serde\"");
66 ensure_attr(attrs, "schemars", "crate = \"typeshift::schemars\"");
67 if include_validate {
68 ensure_attr(attrs, "validate", "crate = \"typeshift::validator\"");
69 }
70}
71
72fn has_derived_trait(attrs: &[Attribute], trait_name: &str) -> bool {
73 attrs
74 .iter()
75 .filter(|attr| attr.path().is_ident("derive"))
76 .filter_map(|attr| {
77 attr.parse_args_with(
78 syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
79 )
80 .ok()
81 })
82 .flat_map(|paths| paths.into_iter())
83 .any(|path| {
84 path.segments
85 .last()
86 .map(|seg| seg.ident == trait_name)
87 .unwrap_or(false)
88 })
89}
90
91fn add_missing_derives(attrs: &mut Vec<Attribute>, required: &[&str]) {
92 let mut missing = Vec::new();
93 for name in required {
94 if has_derived_trait(attrs, name) {
95 continue;
96 }
97 let path: syn::Path = match *name {
98 "Serialize" => syn::parse_quote!(::typeshift::serde::Serialize),
99 "Deserialize" => syn::parse_quote!(::typeshift::serde::Deserialize),
100 "Validate" => syn::parse_quote!(::typeshift::validator::Validate),
101 "JsonSchema" => syn::parse_quote!(::typeshift::schemars::JsonSchema),
102 _ => continue,
103 };
104 missing.push(path);
105 }
106
107 if !missing.is_empty() {
108 let insert_at = attrs
109 .iter()
110 .rposition(|attr| attr.path().is_ident("derive"))
111 .map(|index| index + 1)
112 .unwrap_or(0);
113 attrs.insert(insert_at, syn::parse_quote!(#[derive(#(#missing),*)]));
114 }
115}
116
117fn ensure_attr(attrs: &mut Vec<Attribute>, name: &str, args: &str) {
118 let path = syn::Ident::new(name, proc_macro2::Span::call_site());
119 let args: proc_macro2::TokenStream = match args.parse() {
120 Ok(args) => args,
121 Err(_) => return,
122 };
123
124 let has_crate_arg = attrs
125 .iter()
126 .any(|attr| attr.path().is_ident(name) && attr_has_crate_arg(attr));
127
128 if !has_crate_arg {
129 attrs.push(syn::parse_quote!(#[#path(#args)]));
130 }
131}
132
133fn attr_has_crate_arg(attr: &Attribute) -> bool {
134 attr.parse_args_with(syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated)
135 .map(|metas| {
136 metas.into_iter().any(|meta| {
137 if let syn::Meta::NameValue(name_value) = meta {
138 return name_value.path.is_ident("crate");
139 }
140 false
141 })
142 })
143 .unwrap_or(false)
144}