type_macro_derive_tricks/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use proc_macro2::TokenStream as TokenStream2;
5use rand::{distributions::Alphanumeric, Rng};
6use std::collections::HashMap;
7use syn::{
8    parse_macro_input, punctuated::Punctuated, Data, DeriveInput, Fields, Generics, Ident, Type,
9};
10use template_quote::quote;
11
12/// Main procedural macro that handles types with macros in type positions
13///
14/// Usage: `#[macro_derive(Trait1, Trait2, ...)]`
15///
16/// This macro:
17/// 1. Identifies all macro invocations in type positions
18/// 2. Generates unique type aliases for each macro type
19/// 3. Replaces the macro types with the aliases
20/// 4. Applies the specified derive traits to the transformed type
21#[proc_macro_attribute]
22pub fn macro_derive(args: TokenStream, input: TokenStream) -> TokenStream {
23    let derive_traits = parse_derive_traits(args);
24    let input = parse_macro_input!(input as DeriveInput);
25
26    let expanded = impl_type_macro_derive_tricks(&derive_traits, &input);
27    TokenStream::from(expanded)
28}
29
30fn parse_derive_traits(args: TokenStream) -> Vec<syn::Path> {
31    let args = TokenStream2::from(args);
32
33    if args.is_empty() {
34        return Vec::new();
35    }
36
37    // Parse comma-separated list of trait names
38    let mut traits = Vec::new();
39    let mut current_trait = String::new();
40
41    for token in args.into_iter() {
42        match token {
43            proc_macro2::TokenTree::Punct(punct) if punct.as_char() == ',' => {
44                if !current_trait.is_empty() {
45                    if let Ok(path) = syn::parse_str::<syn::Path>(current_trait.trim()) {
46                        traits.push(path);
47                    }
48                    current_trait.clear();
49                }
50            }
51            _ => {
52                current_trait.push_str(&token.to_string());
53            }
54        }
55    }
56
57    // Don't forget the last trait
58    if !current_trait.is_empty() {
59        if let Ok(path) = syn::parse_str::<syn::Path>(current_trait.trim()) {
60            traits.push(path);
61        }
62    }
63
64    traits
65}
66
67fn impl_type_macro_derive_tricks(derive_traits: &[syn::Path], input: &DeriveInput) -> TokenStream2 {
68    let mut macro_types = HashMap::new();
69    let mut type_aliases = Vec::new();
70
71    // Step 1: Collect all macro types and generate aliases
72    collect_macro_types(&input.data, &input.generics, &mut macro_types);
73
74    // Step 2: Generate type aliases
75    for (macro_type, alias_name) in &macro_types {
76        // Generate type aliases with only the specific generic parameters used by the macro
77        // and add #[doc(hidden)] to hide them from documentation
78        let used_generic_params = get_used_generic_params(macro_type, &input.generics);
79
80        let alias = if used_generic_params.is_empty() {
81            quote! {
82                #[doc(hidden)]
83                type #alias_name = #macro_type;
84            }
85        } else {
86            // Create a filtered Generics struct with only the used parameters
87            let filtered_generics = create_filtered_generics(&used_generic_params)
88                .params
89                .into_iter()
90                .map(|mut param| {
91                    match &mut param {
92                        syn::GenericParam::Type(tp) => {
93                            tp.eq_token = None;
94                            tp.default = None;
95                        }
96                        syn::GenericParam::Const(cp) => {
97                            cp.eq_token = None;
98                            cp.default = None;
99                        }
100                        _ => (),
101                    }
102                    param
103                })
104                .collect::<Punctuated<_, syn::Token![,]>>();
105            quote! {
106                #[doc(hidden)]
107                type #alias_name <#filtered_generics> = #macro_type;
108            }
109        };
110        type_aliases.push(alias);
111    }
112
113    // Step 3: Transform the original type by replacing macro types with aliases
114    let transformed_input = transform_input(input, &macro_types);
115
116    // Step 4: Generate derive attribute
117    let derive_attrs = if !derive_traits.is_empty() {
118        let traits: Vec<_> = derive_traits.iter().collect();
119        quote! {
120            #[derive(#(#traits),*)]
121        }
122    } else {
123        quote! {}
124    };
125
126    // Step 5: Combine everything
127    quote! {
128        #(#type_aliases)*
129
130        #derive_attrs
131        #transformed_input
132    }
133}
134
135fn collect_macro_types(data: &Data, generics: &Generics, macro_types: &mut HashMap<Type, Ident>) {
136    match data {
137        Data::Struct(data_struct) => {
138            collect_macro_types_from_fields(&data_struct.fields, generics, macro_types);
139        }
140        Data::Enum(data_enum) => {
141            for variant in &data_enum.variants {
142                collect_macro_types_from_fields(&variant.fields, generics, macro_types);
143            }
144        }
145        Data::Union(data_union) => {
146            collect_macro_types_from_fields(
147                &Fields::Named(data_union.fields.clone()),
148                generics,
149                macro_types,
150            );
151        }
152    }
153}
154
155fn collect_macro_types_from_fields(
156    fields: &Fields,
157    generics: &Generics,
158    macro_types: &mut HashMap<Type, Ident>,
159) {
160    match fields {
161        Fields::Named(fields) => {
162            for field in &fields.named {
163                collect_macro_types_from_type(&field.ty, generics, macro_types);
164            }
165        }
166        Fields::Unnamed(fields) => {
167            for field in &fields.unnamed {
168                collect_macro_types_from_type(&field.ty, generics, macro_types);
169            }
170        }
171        Fields::Unit => {}
172    }
173}
174
175fn collect_macro_types_from_type(
176    ty: &Type,
177    _generics: &Generics,
178    macro_types: &mut HashMap<Type, Ident>,
179) {
180    // Handle macro types directly - create aliases only for actual macro invocations
181    if let Type::Macro(_) = ty {
182        if !macro_types.contains_key(ty) {
183            let alias_name = generate_random_type_name();
184            macro_types.insert(ty.clone(), alias_name);
185        }
186        return;
187    }
188
189    // Recursively check all nested types for macro invocations
190    match ty {
191        Type::Path(type_path) => {
192            for segment in &type_path.path.segments {
193                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
194                    for arg in &args.args {
195                        if let syn::GenericArgument::Type(nested_ty) = arg {
196                            collect_macro_types_from_type(nested_ty, _generics, macro_types);
197                        }
198                    }
199                }
200            }
201        }
202        Type::Array(type_array) => {
203            collect_macro_types_from_type(&type_array.elem, _generics, macro_types);
204        }
205        Type::Ptr(type_ptr) => {
206            collect_macro_types_from_type(&type_ptr.elem, _generics, macro_types);
207        }
208        Type::Reference(type_ref) => {
209            collect_macro_types_from_type(&type_ref.elem, _generics, macro_types);
210        }
211        Type::Slice(type_slice) => {
212            collect_macro_types_from_type(&type_slice.elem, _generics, macro_types);
213        }
214        Type::Tuple(type_tuple) => {
215            for elem in &type_tuple.elems {
216                collect_macro_types_from_type(elem, _generics, macro_types);
217            }
218        }
219        _ => {}
220    }
221}
222
223fn generate_random_type_name() -> Ident {
224    let random_suffix: String = rand::thread_rng()
225        .sample_iter(&Alphanumeric)
226        .take(12)
227        .map(char::from)
228        .collect();
229
230    Ident::new(
231        &format!("__TypeMacroAlias{}", random_suffix),
232        proc_macro2::Span::call_site(),
233    )
234}
235
236fn get_used_generic_params(macro_type: &Type, generics: &Generics) -> Vec<syn::GenericParam> {
237    // Analyze which specific generic parameters are used in the macro type
238    let mut used_params = Vec::new();
239
240    if let Type::Macro(type_macro) = macro_type {
241        let macro_tokens = &type_macro.mac.tokens;
242
243        for param in &generics.params {
244            let param_name = match param {
245                syn::GenericParam::Type(type_param) => type_param.ident.to_string(),
246                syn::GenericParam::Lifetime(lifetime_param) => lifetime_param.lifetime.to_string(),
247                syn::GenericParam::Const(const_param) => const_param.ident.to_string(),
248            };
249
250            // Use the improved token search that handles nested structures
251            if is_generic_param_used_in_token_stream(macro_tokens, &param_name) {
252                used_params.push(param.clone());
253            }
254        }
255    }
256
257    used_params
258}
259
260fn is_generic_param_used_in_token_stream(
261    tokens: &proc_macro2::TokenStream,
262    identifier: &str,
263) -> bool {
264    use proc_macro2::TokenTree;
265
266    let tokens_vec: Vec<TokenTree> = tokens.clone().into_iter().collect();
267
268    for (i, token) in tokens_vec.iter().enumerate() {
269        match token {
270            TokenTree::Ident(ident) => {
271                // Handle regular type parameters and const parameters
272                if *ident == identifier {
273                    return true;
274                }
275            }
276            TokenTree::Group(group) => {
277                // Recursively search inside groups (brackets, braces, parentheses)
278                if is_generic_param_used_in_token_stream(&group.stream(), identifier) {
279                    return true;
280                }
281            }
282            TokenTree::Punct(punct) => {
283                // Handle lifetimes: look for ' followed by an identifier
284                if punct.as_char() == '\'' && i + 1 < tokens_vec.len() {
285                    if let TokenTree::Ident(ident) = &tokens_vec[i + 1] {
286                        let lifetime = format!("'{}", ident);
287                        if lifetime == identifier {
288                            return true;
289                        }
290                    }
291                }
292            }
293            TokenTree::Literal(_) => {
294                // Literals don't contain type parameters
295                continue;
296            }
297        }
298    }
299
300    false
301}
302
303fn create_filtered_generics(used_params: &[syn::GenericParam]) -> syn::Generics {
304    // Create a new Generics struct containing only the used parameters
305    let mut generics = syn::Generics::default();
306
307    for param in used_params {
308        generics.params.push(param.clone());
309    }
310
311    generics
312}
313
314fn transform_input(input: &DeriveInput, macro_types: &HashMap<Type, Ident>) -> DeriveInput {
315    let mut transformed = input.clone();
316
317    match &mut transformed.data {
318        Data::Struct(data_struct) => {
319            transform_fields(&mut data_struct.fields, macro_types, &input.generics);
320        }
321        Data::Enum(data_enum) => {
322            for variant in &mut data_enum.variants {
323                transform_fields(&mut variant.fields, macro_types, &input.generics);
324            }
325        }
326        Data::Union(data_union) => {
327            let mut fields = Fields::Named(data_union.fields.clone());
328            transform_fields(&mut fields, macro_types, &input.generics);
329            if let Fields::Named(named_fields) = fields {
330                data_union.fields = named_fields;
331            }
332        }
333    }
334
335    transformed
336}
337
338fn transform_fields(fields: &mut Fields, macro_types: &HashMap<Type, Ident>, generics: &Generics) {
339    match fields {
340        Fields::Named(fields) => {
341            for field in &mut fields.named {
342                transform_type(&mut field.ty, macro_types, generics);
343            }
344        }
345        Fields::Unnamed(fields) => {
346            for field in &mut fields.unnamed {
347                transform_type(&mut field.ty, macro_types, generics);
348            }
349        }
350        Fields::Unit => {}
351    }
352}
353
354fn transform_type(ty: &mut Type, macro_types: &HashMap<Type, Ident>, generics: &Generics) {
355    // Handle macro types directly
356    if let Type::Macro(_) = ty {
357        // Check if this macro type has an alias
358        if let Some(alias) = macro_types.get(ty) {
359            let used_generic_params = get_used_generic_params(ty, generics);
360
361            if used_generic_params.is_empty() {
362                *ty = syn::parse_quote!(#alias);
363            } else {
364                // Create filtered generics and use them
365                let filtered_generics = create_filtered_generics(&used_generic_params);
366                let (_, ty_generics, _) = filtered_generics.split_for_impl();
367                *ty = syn::parse_quote!(#alias #ty_generics);
368            }
369        }
370        return;
371    }
372
373    // Recursively transform nested types, looking for macro parts within them
374    match ty {
375        Type::Path(type_path) => {
376            for segment in &mut type_path.path.segments {
377                if let syn::PathArguments::AngleBracketed(args) = &mut segment.arguments {
378                    for arg in &mut args.args {
379                        if let syn::GenericArgument::Type(nested_ty) = arg {
380                            transform_type(nested_ty, macro_types, generics);
381                        }
382                    }
383                }
384            }
385        }
386        Type::Array(type_array) => {
387            transform_type(&mut type_array.elem, macro_types, generics);
388        }
389        Type::Ptr(type_ptr) => {
390            transform_type(&mut type_ptr.elem, macro_types, generics);
391        }
392        Type::Reference(type_ref) => {
393            transform_type(&mut type_ref.elem, macro_types, generics);
394        }
395        Type::Slice(type_slice) => {
396            transform_type(&mut type_slice.elem, macro_types, generics);
397        }
398        Type::Tuple(type_tuple) => {
399            for elem in &mut type_tuple.elems {
400                transform_type(elem, macro_types, generics);
401            }
402        }
403        _ => {}
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_random_type_name() {
        let name1 = generate_random_type_name();
        let name2 = generate_random_type_name();

        assert_ne!(name1, name2);
        assert!(name1.to_string().starts_with("__TypeMacroAlias"));
        assert!(name2.to_string().starts_with("__TypeMacroAlias"));
    }

    /// Parameter detection must look inside nested groups and recognise
    /// lifetimes by their apostrophe.
    #[test]
    fn test_generic_param_detection_descends_into_groups() {
        let tokens: TokenStream2 = "Vec<(T, &'a [u8; N])>".parse().unwrap();

        assert!(is_generic_param_used_in_token_stream(&tokens, "T"));
        assert!(is_generic_param_used_in_token_stream(&tokens, "N"));
        assert!(is_generic_param_used_in_token_stream(&tokens, "'a"));
        assert!(!is_generic_param_used_in_token_stream(&tokens, "U"));
        assert!(!is_generic_param_used_in_token_stream(&tokens, "'b"));
    }

    /// Identical macro types share one alias, nested macro types are found, and
    /// the transformed item no longer contains any macro invocation.
    #[test]
    fn test_collect_and_transform_macro_types() {
        let input: DeriveInput = syn::parse_quote! {
            struct Example<'a, T, U> {
                direct: my_macro!(T),
                nested: Vec<my_macro!(T)>,
                array: &'a [my_macro!(u8); 4],
                plain: U,
            }
        };

        let mut macro_types = HashMap::new();
        collect_macro_types(&input.data, &input.generics, &mut macro_types);
        // `my_macro!(T)` appears twice but is aliased only once.
        assert_eq!(macro_types.len(), 2);

        let uses_t: Type = syn::parse_quote!(my_macro!(T));
        assert_eq!(get_used_generic_params(&uses_t, &input.generics).len(), 1);

        let transformed = transform_input(&input, &macro_types);
        let output = quote! { #transformed }.to_string();
        assert!(!output.contains("my_macro"));
        assert!(output.contains("__TypeMacroAlias"));
    }
}