type_macro_derive_tricks/
lib.rs1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use rand::{distributions::Alphanumeric, Rng};
6use std::collections::HashMap;
7use syn::{
8 parse_macro_input, punctuated::Punctuated, Data, DeriveInput, Fields, Generics, Ident, Type,
9};
10use template_quote::quote;
11
12#[proc_macro_attribute]
22pub fn macro_derive(args: TokenStream, input: TokenStream) -> TokenStream {
23 let derive_traits = parse_derive_traits(args);
24 let input = parse_macro_input!(input as DeriveInput);
25
26 let expanded = impl_type_macro_derive_tricks(&derive_traits, &input);
27 TokenStream::from(expanded)
28}
29
30fn parse_derive_traits(args: TokenStream) -> Vec<syn::Path> {
31 let args = TokenStream2::from(args);
32
33 if args.is_empty() {
34 return Vec::new();
35 }
36
37 let mut traits = Vec::new();
39 let mut current_trait = String::new();
40
41 for token in args.into_iter() {
42 match token {
43 proc_macro2::TokenTree::Punct(punct) if punct.as_char() == ',' => {
44 if !current_trait.is_empty() {
45 if let Ok(path) = syn::parse_str::<syn::Path>(current_trait.trim()) {
46 traits.push(path);
47 }
48 current_trait.clear();
49 }
50 }
51 _ => {
52 current_trait.push_str(&token.to_string());
53 }
54 }
55 }
56
57 if !current_trait.is_empty() {
59 if let Ok(path) = syn::parse_str::<syn::Path>(current_trait.trim()) {
60 traits.push(path);
61 }
62 }
63
64 traits
65}
66
67fn impl_type_macro_derive_tricks(derive_traits: &[syn::Path], input: &DeriveInput) -> TokenStream2 {
68 let mut macro_types = HashMap::new();
69 let mut type_aliases = Vec::new();
70
71 collect_macro_types(&input.data, &input.generics, &mut macro_types);
73
74 for (macro_type, alias_name) in ¯o_types {
76 let used_generic_params = get_used_generic_params(macro_type, &input.generics);
79
80 let alias = if used_generic_params.is_empty() {
81 quote! {
82 #[doc(hidden)]
83 type #alias_name = #macro_type;
84 }
85 } else {
86 let filtered_generics = create_filtered_generics(&used_generic_params)
88 .params
89 .into_iter()
90 .map(|mut param| {
91 match &mut param {
92 syn::GenericParam::Type(tp) => {
93 tp.eq_token = None;
94 tp.default = None;
95 }
96 syn::GenericParam::Const(cp) => {
97 cp.eq_token = None;
98 cp.default = None;
99 }
100 _ => (),
101 }
102 param
103 })
104 .collect::<Punctuated<_, syn::Token![,]>>();
105 quote! {
106 #[doc(hidden)]
107 type #alias_name <#filtered_generics> = #macro_type;
108 }
109 };
110 type_aliases.push(alias);
111 }
112
113 let transformed_input = transform_input(input, ¯o_types);
115
116 let derive_attrs = if !derive_traits.is_empty() {
118 let traits: Vec<_> = derive_traits.iter().collect();
119 quote! {
120 #[derive(#(#traits),*)]
121 }
122 } else {
123 quote! {}
124 };
125
126 quote! {
128 #(#type_aliases)*
129
130 #derive_attrs
131 #transformed_input
132 }
133}
134
135fn collect_macro_types(data: &Data, generics: &Generics, macro_types: &mut HashMap<Type, Ident>) {
136 match data {
137 Data::Struct(data_struct) => {
138 collect_macro_types_from_fields(&data_struct.fields, generics, macro_types);
139 }
140 Data::Enum(data_enum) => {
141 for variant in &data_enum.variants {
142 collect_macro_types_from_fields(&variant.fields, generics, macro_types);
143 }
144 }
145 Data::Union(data_union) => {
146 collect_macro_types_from_fields(
147 &Fields::Named(data_union.fields.clone()),
148 generics,
149 macro_types,
150 );
151 }
152 }
153}
154
155fn collect_macro_types_from_fields(
156 fields: &Fields,
157 generics: &Generics,
158 macro_types: &mut HashMap<Type, Ident>,
159) {
160 match fields {
161 Fields::Named(fields) => {
162 for field in &fields.named {
163 collect_macro_types_from_type(&field.ty, generics, macro_types);
164 }
165 }
166 Fields::Unnamed(fields) => {
167 for field in &fields.unnamed {
168 collect_macro_types_from_type(&field.ty, generics, macro_types);
169 }
170 }
171 Fields::Unit => {}
172 }
173}
174
175fn collect_macro_types_from_type(
176 ty: &Type,
177 _generics: &Generics,
178 macro_types: &mut HashMap<Type, Ident>,
179) {
180 if let Type::Macro(_) = ty {
182 if !macro_types.contains_key(ty) {
183 let alias_name = generate_random_type_name();
184 macro_types.insert(ty.clone(), alias_name);
185 }
186 return;
187 }
188
189 match ty {
191 Type::Path(type_path) => {
192 for segment in &type_path.path.segments {
193 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
194 for arg in &args.args {
195 if let syn::GenericArgument::Type(nested_ty) = arg {
196 collect_macro_types_from_type(nested_ty, _generics, macro_types);
197 }
198 }
199 }
200 }
201 }
202 Type::Array(type_array) => {
203 collect_macro_types_from_type(&type_array.elem, _generics, macro_types);
204 }
205 Type::Ptr(type_ptr) => {
206 collect_macro_types_from_type(&type_ptr.elem, _generics, macro_types);
207 }
208 Type::Reference(type_ref) => {
209 collect_macro_types_from_type(&type_ref.elem, _generics, macro_types);
210 }
211 Type::Slice(type_slice) => {
212 collect_macro_types_from_type(&type_slice.elem, _generics, macro_types);
213 }
214 Type::Tuple(type_tuple) => {
215 for elem in &type_tuple.elems {
216 collect_macro_types_from_type(elem, _generics, macro_types);
217 }
218 }
219 _ => {}
220 }
221}
222
223fn generate_random_type_name() -> Ident {
224 let random_suffix: String = rand::thread_rng()
225 .sample_iter(&Alphanumeric)
226 .take(12)
227 .map(char::from)
228 .collect();
229
230 Ident::new(
231 &format!("__TypeMacroAlias{}", random_suffix),
232 proc_macro2::Span::call_site(),
233 )
234}
235
236fn get_used_generic_params(macro_type: &Type, generics: &Generics) -> Vec<syn::GenericParam> {
237 let mut used_params = Vec::new();
239
240 if let Type::Macro(type_macro) = macro_type {
241 let macro_tokens = &type_macro.mac.tokens;
242
243 for param in &generics.params {
244 let param_name = match param {
245 syn::GenericParam::Type(type_param) => type_param.ident.to_string(),
246 syn::GenericParam::Lifetime(lifetime_param) => lifetime_param.lifetime.to_string(),
247 syn::GenericParam::Const(const_param) => const_param.ident.to_string(),
248 };
249
250 if is_generic_param_used_in_token_stream(macro_tokens, ¶m_name) {
252 used_params.push(param.clone());
253 }
254 }
255 }
256
257 used_params
258}
259
260fn is_generic_param_used_in_token_stream(
261 tokens: &proc_macro2::TokenStream,
262 identifier: &str,
263) -> bool {
264 use proc_macro2::TokenTree;
265
266 let tokens_vec: Vec<TokenTree> = tokens.clone().into_iter().collect();
267
268 for (i, token) in tokens_vec.iter().enumerate() {
269 match token {
270 TokenTree::Ident(ident) => {
271 if *ident == identifier {
273 return true;
274 }
275 }
276 TokenTree::Group(group) => {
277 if is_generic_param_used_in_token_stream(&group.stream(), identifier) {
279 return true;
280 }
281 }
282 TokenTree::Punct(punct) => {
283 if punct.as_char() == '\'' && i + 1 < tokens_vec.len() {
285 if let TokenTree::Ident(ident) = &tokens_vec[i + 1] {
286 let lifetime = format!("'{}", ident);
287 if lifetime == identifier {
288 return true;
289 }
290 }
291 }
292 }
293 TokenTree::Literal(_) => {
294 continue;
296 }
297 }
298 }
299
300 false
301}
302
303fn create_filtered_generics(used_params: &[syn::GenericParam]) -> syn::Generics {
304 let mut generics = syn::Generics::default();
306
307 for param in used_params {
308 generics.params.push(param.clone());
309 }
310
311 generics
312}
313
314fn transform_input(input: &DeriveInput, macro_types: &HashMap<Type, Ident>) -> DeriveInput {
315 let mut transformed = input.clone();
316
317 match &mut transformed.data {
318 Data::Struct(data_struct) => {
319 transform_fields(&mut data_struct.fields, macro_types, &input.generics);
320 }
321 Data::Enum(data_enum) => {
322 for variant in &mut data_enum.variants {
323 transform_fields(&mut variant.fields, macro_types, &input.generics);
324 }
325 }
326 Data::Union(data_union) => {
327 let mut fields = Fields::Named(data_union.fields.clone());
328 transform_fields(&mut fields, macro_types, &input.generics);
329 if let Fields::Named(named_fields) = fields {
330 data_union.fields = named_fields;
331 }
332 }
333 }
334
335 transformed
336}
337
338fn transform_fields(fields: &mut Fields, macro_types: &HashMap<Type, Ident>, generics: &Generics) {
339 match fields {
340 Fields::Named(fields) => {
341 for field in &mut fields.named {
342 transform_type(&mut field.ty, macro_types, generics);
343 }
344 }
345 Fields::Unnamed(fields) => {
346 for field in &mut fields.unnamed {
347 transform_type(&mut field.ty, macro_types, generics);
348 }
349 }
350 Fields::Unit => {}
351 }
352}
353
354fn transform_type(ty: &mut Type, macro_types: &HashMap<Type, Ident>, generics: &Generics) {
355 if let Type::Macro(_) = ty {
357 if let Some(alias) = macro_types.get(ty) {
359 let used_generic_params = get_used_generic_params(ty, generics);
360
361 if used_generic_params.is_empty() {
362 *ty = syn::parse_quote!(#alias);
363 } else {
364 let filtered_generics = create_filtered_generics(&used_generic_params);
366 let (_, ty_generics, _) = filtered_generics.split_for_impl();
367 *ty = syn::parse_quote!(#alias #ty_generics);
368 }
369 }
370 return;
371 }
372
373 match ty {
375 Type::Path(type_path) => {
376 for segment in &mut type_path.path.segments {
377 if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments {
378 for arg in &mut args.args {
379 if let syn::GenericArgument::Type(nested_ty) = arg {
380 transform_type(nested_ty, macro_types, generics);
381 }
382 }
383 }
384 }
385 }
386 Type::Array(type_array) => {
387 transform_type(&mut type_array.elem, macro_types, generics);
388 }
389 Type::Ptr(type_ptr) => {
390 transform_type(&mut type_ptr.elem, macro_types, generics);
391 }
392 Type::Reference(type_ref) => {
393 transform_type(&mut type_ref.elem, macro_types, generics);
394 }
395 Type::Slice(type_slice) => {
396 transform_type(&mut type_slice.elem, macro_types, generics);
397 }
398 Type::Tuple(type_tuple) => {
399 for elem in &mut type_tuple.elems {
400 transform_type(elem, macro_types, generics);
401 }
402 }
403 _ => {}
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// Name of a generic parameter as `get_used_generic_params` looks it up
    /// (lifetimes include their leading quote).
    fn param_name(param: &syn::GenericParam) -> String {
        match param {
            syn::GenericParam::Type(tp) => tp.ident.to_string(),
            syn::GenericParam::Lifetime(lp) => lp.lifetime.to_string(),
            syn::GenericParam::Const(cp) => cp.ident.to_string(),
        }
    }

    #[test]
    fn test_generate_random_type_name() {
        let name1 = generate_random_type_name();
        let name2 = generate_random_type_name();

        assert_ne!(name1, name2);
        assert!(name1.to_string().starts_with("__TypeMacroAlias"));
        assert!(name2.to_string().starts_with("__TypeMacroAlias"));
    }

    #[test]
    fn finds_type_param_inside_nested_groups() {
        let tokens: TokenStream2 = "Vec<(u8, [T; 3])>".parse().unwrap();
        assert!(is_generic_param_used_in_token_stream(&tokens, "T"));
        assert!(!is_generic_param_used_in_token_stream(&tokens, "U"));
    }

    #[test]
    fn finds_lifetime_param() {
        let tokens: TokenStream2 = "&'a str".parse().unwrap();
        assert!(is_generic_param_used_in_token_stream(&tokens, "'a"));
        assert!(!is_generic_param_used_in_token_stream(&tokens, "'b"));
    }

    #[test]
    fn ignores_string_and_char_literals() {
        let tokens: TokenStream2 = "\"T\" 'T'".parse().unwrap();
        assert!(!is_generic_param_used_in_token_stream(&tokens, "T"));
        assert!(!is_generic_param_used_in_token_stream(&tokens, "'T"));
    }

    #[test]
    fn used_generic_params_follow_declaration_order() {
        let ty: Type = syn::parse_quote!(my_ty!(U, 'a, T));
        let generics: Generics = syn::parse_quote!(<'a, T, U, V>);

        let names: Vec<String> = get_used_generic_params(&ty, &generics)
            .iter()
            .map(param_name)
            .collect();
        assert_eq!(names, ["'a", "T", "U"]);
    }

    #[test]
    fn non_macro_type_uses_no_generic_params() {
        let ty: Type = syn::parse_quote!(Vec<T>);
        let generics: Generics = syn::parse_quote!(<T>);
        assert!(get_used_generic_params(&ty, &generics).is_empty());
    }

    #[test]
    fn filtered_generics_keep_only_given_params() {
        let generics: Generics = syn::parse_quote!(<'a, T, const N: usize>);
        let params: Vec<syn::GenericParam> = generics.params.iter().cloned().collect();

        let filtered = create_filtered_generics(&params[1..]);
        let names: Vec<String> = filtered.params.iter().map(param_name).collect();
        assert_eq!(names, ["T", "N"]);
        assert!(filtered.where_clause.is_none());
        assert!(create_filtered_generics(&[]).params.is_empty());
    }
}