1use proc_macro::TokenStream;
8use quote::{format_ident, quote};
9use syn::{Attribute, Fields, Item, ItemEnum, parse_macro_input};
10
11#[proc_macro_attribute]
12pub fn typeshift(_attr: TokenStream, item: TokenStream) -> TokenStream {
13 let mut item = parse_macro_input!(item as Item);
14
15 match &mut item {
16 Item::Struct(input) => {
17 apply_typeshift_attrs(&mut input.attrs, true);
18 quote!(#input).into()
19 }
20 Item::Enum(input) => {
21 apply_typeshift_attrs(&mut input.attrs, false);
22
23 let validate_impl = build_enum_validate_impl(input);
24
25 quote! {
26 #input
27 #validate_impl
28 }
29 .into()
30 }
31 _ => syn::Error::new_spanned(item, "#[typeshift] supports structs and enums only")
32 .to_compile_error()
33 .into(),
34 }
35}
36
37fn build_enum_validate_impl(input: &ItemEnum) -> proc_macro2::TokenStream {
38 if has_derived_trait(&input.attrs, "Validate") {
39 return quote! {};
40 }
41
42 let ident = &input.ident;
43 let generics = &input.generics;
44 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
45 let helper_generics_def = helper_def_generics(generics);
46 let helper_generics_use = helper_use_generics(generics);
47
48 let helper_defs = input.variants.iter().filter_map(|variant| {
49 let variant_ident = &variant.ident;
50 let helper_ident = format_ident!("__TypeShiftValidate{}{}", ident, variant_ident);
51 match &variant.fields {
52 Fields::Unit => None,
53 Fields::Named(fields) => {
54 let defs = fields.named.iter().map(|field| {
55 let attrs = validate_attrs(&field.attrs);
56 let name = match &field.ident {
57 Some(name) => name,
58 None => unreachable!("named field must have ident"),
59 };
60 let ty = &field.ty;
61 quote! { #(#attrs)* #name: &'__typeshift_enum_validate #ty }
62 });
63
64 Some(quote! {
65 #[allow(dead_code)]
66 #[derive(::typeshift::validator::Validate)]
67 #[validate(crate = "typeshift::validator")]
68 struct #helper_ident #helper_generics_def #where_clause {
69 #(#defs,)*
70 }
71 })
72 }
73 Fields::Unnamed(fields) => {
74 let defs = fields.unnamed.iter().enumerate().map(|(idx, field)| {
75 let attrs = validate_attrs(&field.attrs);
76 let name = format_ident!("__field_{idx}");
77 let ty = &field.ty;
78 quote! { #(#attrs)* #name: &'__typeshift_enum_validate #ty }
79 });
80
81 Some(quote! {
82 #[allow(dead_code)]
83 #[derive(::typeshift::validator::Validate)]
84 #[validate(crate = "typeshift::validator")]
85 struct #helper_ident #helper_generics_def #where_clause {
86 #(#defs,)*
87 }
88 })
89 }
90 }
91 });
92
93 let arms = input.variants.iter().map(|variant| {
94 let variant_ident = &variant.ident;
95 let helper_ident = format_ident!("__TypeShiftValidate{}{}", ident, variant_ident);
96 match &variant.fields {
97 Fields::Unit => {
98 quote! {
99 Self::#variant_ident => ::core::result::Result::Ok(())
100 }
101 }
102 Fields::Named(fields) => {
103 let names: Vec<_> = fields
104 .named
105 .iter()
106 .filter_map(|field| field.ident.as_ref())
107 .collect();
108 quote! {
109 Self::#variant_ident { #(#names,)* } => {
110 let helper = #helper_ident #helper_generics_use { #(#names,)* };
111 ::typeshift::validator::Validate::validate(&helper)
112 }
113 }
114 }
115 Fields::Unnamed(fields) => {
116 let bindings: Vec<_> = fields
117 .unnamed
118 .iter()
119 .enumerate()
120 .map(|(idx, _)| format_ident!("__field_{idx}"))
121 .collect();
122 let init_fields = bindings.iter().map(|name| quote! { #name: #name });
123 quote! {
124 Self::#variant_ident( #(#bindings,)* ) => {
125 let helper = #helper_ident #helper_generics_use { #(#init_fields,)* };
126 ::typeshift::validator::Validate::validate(&helper)
127 }
128 }
129 }
130 }
131 });
132
133 quote! {
134 #(#helper_defs)*
135
136 impl #impl_generics ::typeshift::validator::Validate for #ident #ty_generics #where_clause {
137 fn validate(&self) -> ::core::result::Result<(), ::typeshift::validator::ValidationErrors> {
138 match self {
139 #(#arms,)*
140 }
141 }
142 }
143 }
144}
145
146fn helper_def_generics(generics: &syn::Generics) -> proc_macro2::TokenStream {
147 let params = &generics.params;
148 if params.is_empty() {
149 quote! { <'__typeshift_enum_validate> }
150 } else {
151 quote! { <'__typeshift_enum_validate, #params> }
152 }
153}
154
155fn helper_use_generics(generics: &syn::Generics) -> proc_macro2::TokenStream {
156 let args: Vec<proc_macro2::TokenStream> = generics
157 .params
158 .iter()
159 .map(|param| match param {
160 syn::GenericParam::Type(ty) => {
161 let ident = &ty.ident;
162 quote! { #ident }
163 }
164 syn::GenericParam::Lifetime(lt) => {
165 let lifetime = <.lifetime;
166 quote! { #lifetime }
167 }
168 syn::GenericParam::Const(konst) => {
169 let ident = &konst.ident;
170 quote! { #ident }
171 }
172 })
173 .collect();
174
175 if args.is_empty() {
176 quote! { ::<'_> }
177 } else {
178 quote! { ::<'_, #(#args,)*> }
179 }
180}
181
182fn validate_attrs(attrs: &[Attribute]) -> Vec<Attribute> {
183 attrs
184 .iter()
185 .filter(|attr| attr.path().is_ident("validate"))
186 .cloned()
187 .collect()
188}
189
190#[proc_macro_derive(TypeShift, attributes(validate, serde, schemars))]
191pub fn derive_typeshift(_input: TokenStream) -> TokenStream {
196 TokenStream::new()
197}
198
199fn apply_typeshift_attrs(attrs: &mut Vec<Attribute>, include_validate: bool) {
200 let mut required = vec!["Serialize", "Deserialize", "JsonSchema"];
201 if include_validate {
202 required.push("Validate");
203 }
204 add_missing_derives(attrs, &required);
205 ensure_attr(attrs, "serde", "crate = \"typeshift::serde\"");
206 ensure_attr(attrs, "schemars", "crate = \"typeshift::schemars\"");
207 if include_validate {
208 ensure_attr(attrs, "validate", "crate = \"typeshift::validator\"");
209 }
210}
211
212fn has_derived_trait(attrs: &[Attribute], trait_name: &str) -> bool {
213 attrs
214 .iter()
215 .filter(|attr| attr.path().is_ident("derive"))
216 .filter_map(|attr| {
217 attr.parse_args_with(
218 syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated,
219 )
220 .ok()
221 })
222 .flat_map(|paths| paths.into_iter())
223 .any(|path| {
224 path.segments
225 .last()
226 .map(|seg| seg.ident == trait_name)
227 .unwrap_or(false)
228 })
229}
230
231fn add_missing_derives(attrs: &mut Vec<Attribute>, required: &[&str]) {
232 let mut missing = Vec::new();
233 for name in required {
234 if has_derived_trait(attrs, name) {
235 continue;
236 }
237 let path: syn::Path = match *name {
238 "Serialize" => syn::parse_quote!(::typeshift::serde::Serialize),
239 "Deserialize" => syn::parse_quote!(::typeshift::serde::Deserialize),
240 "Validate" => syn::parse_quote!(::typeshift::validator::Validate),
241 "JsonSchema" => syn::parse_quote!(::typeshift::schemars::JsonSchema),
242 _ => continue,
243 };
244 missing.push(path);
245 }
246
247 if !missing.is_empty() {
248 let insert_at = attrs
249 .iter()
250 .rposition(|attr| attr.path().is_ident("derive"))
251 .map(|index| index + 1)
252 .unwrap_or(0);
253 attrs.insert(insert_at, syn::parse_quote!(#[derive(#(#missing),*)]));
254 }
255}
256
257fn ensure_attr(attrs: &mut Vec<Attribute>, name: &str, args: &str) {
258 let path = syn::Ident::new(name, proc_macro2::Span::call_site());
259 let args: proc_macro2::TokenStream = match args.parse() {
260 Ok(args) => args,
261 Err(_) => return,
262 };
263
264 let has_crate_arg = attrs
265 .iter()
266 .any(|attr| attr.path().is_ident(name) && attr_has_crate_arg(attr));
267
268 if !has_crate_arg {
269 attrs.push(syn::parse_quote!(#[#path(#args)]));
270 }
271}
272
273fn attr_has_crate_arg(attr: &Attribute) -> bool {
274 attr.parse_args_with(syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated)
275 .map(|metas| {
276 metas.into_iter().any(|meta| {
277 if let syn::Meta::NameValue(name_value) = meta {
278 return name_value.path.is_ident("crate");
279 }
280 false
281 })
282 })
283 .unwrap_or(false)
284}