trait_variable/
lib.rs

1mod path_utils;
2mod trait_item;
3mod trait_utils;
4
5use path_utils::PathFinder;
6use proc_macro2::TokenStream;
7
8#[allow(unused)]
9use quote::{quote, ToTokens};
10
11use syn::visit::{self, Visit};
12use syn::{
13    braced, token, AngleBracketedGenericArguments, GenericArgument, Generics, PathArguments, Type,
14    TypeParamBound, TypePath, Visibility, WhereClause,
15};
16use syn::{
17    parse::{Parse, ParseStream},
18    parse_macro_input,
19    punctuated::Punctuated,
20    Ident, Token, TraitItem,
21};
22use trait_item::refine_trait_items;
23
/// Collects the names of generic type parameters encountered while visiting a `syn::Type`.
///
/// A name counts as a generic parameter when it is exactly one uppercase ASCII letter
/// (`T`, `K`, `V`, ...).
struct GenericTypeVisitor {
    /// Generic parameter names collected so far.
    generics: Vec<String>,
}
impl GenericTypeVisitor {
    /// Returns `true` iff `ident_str` consists of exactly one uppercase ASCII letter.
    fn is_single_upper_letter(&self, ident_str: &str) -> bool {
        // A one-byte UTF-8 string is necessarily a single ASCII character, and for ASCII
        // characters `char::is_uppercase` agrees with `u8::is_ascii_uppercase`.
        matches!(ident_str.as_bytes(), [only] if only.is_ascii_uppercase())
    }
}
32impl<'ast> Visit<'ast> for GenericTypeVisitor {
33    fn visit_type(&mut self, i: &'ast Type) {
34        if let Type::Path(TypePath { path, .. }) = i {
35            if let Some(PathArguments::AngleBracketed(AngleBracketedGenericArguments {
36                args,
37                ..
38            })) = path.segments.last().map(|seg| &seg.arguments)
39            {
40                for arg in args {
41                    if let GenericArgument::Type(Type::Path(tp)) = arg {
42                        if let Some(ident) = tp.path.get_ident() {
43                            let ident_str = ident.to_string();
44                            if self.is_single_upper_letter(&ident_str)
45                                && !self.generics.contains(&ident_str)
46                            {
47                                self.generics.push(ident_str);
48                            }
49                        }
50                    }
51                }
52            } else if let Some(seg) = path.segments.last() {
53                let ident_str = seg.ident.to_string();
54                if self.is_single_upper_letter(&ident_str) && !self.generics.contains(&ident_str) {
55                    self.generics.push(ident_str);
56                }
57            }
58        }
59        // Continue the traversal to nested types
60        visit::visit_type(self, i);
61    }
62}
#[test]
fn test_generic_type_visitor() {
    // Parse `code` as a type, run a fresh visitor over it and return what it collected.
    fn collect_generics(code: TokenStream) -> Vec<String> {
        let ty: syn::Type = syn::parse2(code).unwrap();
        let mut visitor = GenericTypeVisitor {
            generics: Vec::new(),
        };
        visitor.visit_type(&ty);
        visitor.generics
    }

    // case 1: a lone type parameter
    assert_eq!(collect_generics(quote! { V }), vec!["V"]);

    // case 2: nested generic arguments (syntactically valid, semantically meaningless)
    assert_eq!(
        collect_generics(quote! { Vec<T, HashMap<K, V>> }),
        vec!["T", "K", "V"]
    );
}
84
85/// Define the struct to represent a single trait variable field.
86struct TraitVarField {
87    var_vis: Visibility,
88    var_name: Ident,
89    type_name: Type,
90    type_generics: Vec<String>,
91}
92impl Parse for TraitVarField {
93    fn parse(input: ParseStream) -> syn::Result<Self> {
94        let var_vis: Visibility = input.parse().expect("Failed to Parse to `var_vis`");
95        let var_name: Ident = input.parse().expect("Failed to Parse to `var_name`");
96        let _: Token![:] = input.parse().expect("Failed to Parse to `:`");
97        let type_name: Type = input.parse().expect("Failed to Parse to `type_name`");
98        let type_generics = {
99            let mut visitor = GenericTypeVisitor {
100                generics: Vec::new(),
101            };
102            visitor.visit_type(&type_name);
103            visitor.generics
104        };
105        Ok(TraitVarField {
106            var_vis,
107            var_name,
108            type_name,
109            type_generics,
110        })
111    }
112}
#[test]
fn test_trait_var_field() {
    let TraitVarField {
        var_vis,
        var_name,
        type_name,
        type_generics,
    } = syn::parse2::<TraitVarField>(quote! { pub var_name: Vec<T, HashMap<K, V>> })
        .expect("Failed to parse to `TraitVarField`");

    assert!(
        matches!(var_vis, Visibility::Public(_)),
        "Visibility is not public"
    );
    assert_eq!(var_name.to_string(), "var_name");
    assert_eq!(
        type_name.to_token_stream().to_string(),
        "Vec < T , HashMap < K , V > >"
    );
    assert_eq!(type_generics, ["T", "K", "V"]);
}
133
134struct TraitInput {
135    trait_vis: Visibility,
136    _trait_token: Token![trait],
137    trait_name: Ident,
138    trait_bounds: Option<Generics>, // optional generic parameters for the trait
139    explicit_parent_traits: Option<Punctuated<TypeParamBound, Token![+]>>, // explicit parent traits
140    where_clause: Option<WhereClause>, // optional where clause for the trait
141    _brace_token: token::Brace,
142    trait_variables: Vec<TraitVarField>,
143    trait_items: Vec<TraitItem>,
144}
145
146impl Parse for TraitInput {
147    fn parse(input: ParseStream) -> syn::Result<Self> {
148        let content;
149
150        Ok(TraitInput {
151            trait_vis: input.parse()?,
152            _trait_token: input.parse()?,
153            trait_name: input.parse()?,
154            trait_bounds: if input.peek(Token![<]) {
155                Some(input.parse()?) // Use the parse method to parse the generics
156            } else {
157                None
158            },
159            explicit_parent_traits: if input.peek(Token![:]) {
160                input.parse::<Token![:]>()?;
161                let mut parent_traits = Punctuated::new();
162                while !input.peek(Token![where]) && !input.peek(token::Brace) {
163                    parent_traits.push_value(input.parse()?);
164                    if input.peek(Token![+]) {
165                        parent_traits.push_punct(input.parse()?);
166                    } else {
167                        break;
168                    }
169                }
170                Some(parent_traits)
171            } else {
172                None
173            },
174            where_clause: if input.peek(syn::token::Where) {
175                Some(input.parse()?)
176            } else {
177                None
178            },
179            _brace_token: braced!(content in input),
180            // Parse all variable declarations until a method or end of input is encountered
181            trait_variables: {
182                let mut v = Vec::new();
183                while !content.peek(Token![type])
184                    && !content.peek(Token![const])
185                    && !content.peek(Token![fn])
186                    && !content.is_empty()
187                {
188                    v.push(content.call(TraitVarField::parse)?);
189                    let _: Token![;] = content.parse()?;
190                }
191                v
192            },
193            trait_items: {
194                let mut items = Vec::new();
195                while !content.is_empty() {
196                    items.push(content.parse()?);
197                }
198                items
199            },
200        })
201    }
202}
203
#[test]
fn test_trait_input() {
    let parsed: TraitInput = syn::parse2(quote! {
        pub trait MyTrait {
            x: Vec<T, HashMap<K, V>>;
            pub y: bool;

            fn print_x(&self){
                println!("x: `{}`", self.x);
            }
            fn print_y(&self){
                println!("y: `{}`", self.y);
            }
            fn print_all(&self);
        }
    })
    .unwrap();

    // header: public, with no generics, supertraits or where clause
    assert!(matches!(parsed.trait_vis, Visibility::Public(_)));
    assert_eq!(parsed.trait_name.to_string(), "MyTrait");
    assert!(parsed.trait_bounds.is_none());
    assert!(parsed.explicit_parent_traits.is_none());
    assert!(parsed.where_clause.is_none());

    // body: the two variable fields, in declaration order
    let field_names: Vec<String> = parsed
        .trait_variables
        .iter()
        .map(|field| field.var_name.to_string())
        .collect();
    assert_eq!(field_names, ["x", "y"]);

    // body: the three remaining trait items, printed back as token strings
    let item_strings: Vec<String> = parsed
        .trait_items
        .iter()
        .map(|item| item.to_token_stream().to_string())
        .collect();
    assert_eq!(
        item_strings,
        [
            "fn print_x (& self) { println ! (\"x: `{}`\" , self . x) ; }",
            "fn print_y (& self) { println ! (\"y: `{}`\" , self . y) ; }",
            "fn print_all (& self) ;",
        ]
    );
}
250
251/// functional macro: used to generate code for a trait with variable fields
252#[proc_macro]
253pub fn trait_variable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
254    let TraitInput {
255        trait_vis,
256        trait_name,
257        trait_bounds,
258        explicit_parent_traits,
259        where_clause,
260        trait_variables,
261        trait_items,
262        ..
263    } = parse_macro_input!(input as TraitInput);
264
265    // 1.1 get parent trait name
266    let hidden_parent_trait_name = Ident::new(&format!("_{}", trait_name), trait_name.span());
267    // 1.2 get trait declarative macro name
268    let trait_decl_macro_name =
269        Ident::new(&format!("{}_for_struct", trait_name), trait_name.span());
270
271    // 2.1 generate parent trait methods declaration
272    let hidden_parent_trait_methods_signatures = trait_variables.iter().map(
273        |TraitVarField {
274             var_name,
275             type_name,
276             ..
277         }| {
278            let method_name = Ident::new(&format!("_{}", var_name), var_name.span());
279            let method_name_mut = Ident::new(&format!("_{}_mut", var_name), var_name.span());
280            quote! {
281                fn #method_name(&self) -> &#type_name;
282                fn #method_name_mut(&mut self) -> &mut #type_name;
283            }
284        },
285    );
286    // 2.2 generate trait variable fields definition for structs later
287    let trait_fields_in_struct = trait_variables.iter().map(
288        |TraitVarField {
289             var_vis,
290             var_name,
291             type_name,
292             ..
293         }| {
294            quote! {
295                #var_vis #var_name: #type_name,
296            }
297        },
298    );
299    // 2.3 generate parent trait methods implementation for struct
300    let parent_trait_methods_impls_in_struct = trait_variables.iter().map(
301        |TraitVarField {
302             var_name,
303             type_name,
304             ..
305         }| {
306            let method_name = Ident::new(&format!("_{}", var_name), var_name.span());
307            let method_name_mut = Ident::new(&format!("_{}_mut", var_name), var_name.span());
308            quote! {
309                fn #method_name(&self) -> &#type_name{
310                    &self.#var_name
311                }
312                fn #method_name_mut(&mut self) -> &mut #type_name{
313                    &mut self.#var_name
314                }
315            }
316        },
317    );
318    // 2.4 check if the parent trait has generic type
319    let hidden_parent_trait_bounds = {
320        let mut generic_types = Vec::new();
321        for trait_var in trait_variables.iter() {
322            for generic in &trait_var.type_generics {
323                let generic_ident = syn::Ident::new(generic, proc_macro2::Span::call_site());
324                if !generic_types.contains(&generic_ident) {
325                    generic_types.push(generic_ident);
326                }
327            }
328        }
329        if !generic_types.is_empty() {
330            quote! { <#(#generic_types),*> }
331        } else {
332            TokenStream::new()
333        }
334    };
335
336    // 3. refine the body of methods from the original trait
337    let trait_items = refine_trait_items(trait_items);
338
339    // 4. expand the trait code
340    let hidden_parent_trait_with_bounds =
341        quote! {#hidden_parent_trait_name #hidden_parent_trait_bounds};
342    let expanded_trait_code = quote! {
343        #trait_vis trait #hidden_parent_trait_with_bounds {
344            #(#hidden_parent_trait_methods_signatures)*
345        }
346        #trait_vis trait #trait_name #trait_bounds: #hidden_parent_trait_with_bounds + #explicit_parent_traits #where_clause {
347            #(#trait_items)*
348        }
349    };
350
351    // 5. generate the hidden declarative macro for target struct
352    let declarative_macro_code = quote! {
353        #[doc(hidden)]
354        #[macro_export] // it is ok to always export the declarative macro
355        macro_rules! #trait_decl_macro_name { // NOTE: the reexpanded macro is used for rust struct only
356            (
357                $(#[$struct_attr:meta])* // NOTE: make sure the style is consistent with that in arm 2 output
358                $vis:vis struct $struct_name:ident
359                $(<$($generic_param:ident),* $(, $generic_lifetime:lifetime)* $(,)? >)?
360                // TODO: $(where $($where_clause:tt)*)?
361                {
362                    $($struct_content:tt)*
363                }
364            ) => {
365                // 1. the struct definition block with trait variable fields:
366                $(#[$struct_attr])*
367                $vis struct $struct_name
368                $(<$($generic_param),* $(, $generic_lifetime)*>)?
369                // TODO: $(where $($where_clause)*)?
370                {
371                    $($struct_content)*
372                    #(
373                        #trait_fields_in_struct
374                    )*
375                }
376                // 2. the struct impl block for the hidden parent trait:
377                impl
378                // 2.1 the struct generic+lifetime parameters, if any
379                $(<$($generic_param),* $(, $generic_lifetime)*>)?
380                // 2.2 the hidden parent trait
381                #hidden_parent_trait_with_bounds
382                for
383                // 2.3 the struct name with generic parameters, if any
384                $struct_name
385                $(<$($generic_param),* $(, $generic_lifetime)*>)?
386                {
387                    #(
388                        #parent_trait_methods_impls_in_struct
389                    )*
390                }
391            };
392        }
393    };
394
395    // 6. integrate all expanded code
396    proc_macro::TokenStream::from(quote! {
397        #expanded_trait_code
398        #declarative_macro_code
399    })
400}
401
402/// Define the struct to represent the attribute macro(`trait_var`) arguments.
403struct AttrArgs(Ident);
404impl syn::parse::Parse for AttrArgs {
405    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
406        let ident = input.parse()?;
407        Ok(AttrArgs(ident))
408    }
409}
410/// attribute macro: used to tag Rust struct like: `#[trait_var(<trait_name>)]`
411#[proc_macro_attribute]
412pub fn trait_var(
413    args: proc_macro::TokenStream,
414    input: proc_macro::TokenStream,
415) -> proc_macro::TokenStream {
416    // Convert TokenStream to ParseStream
417    let AttrArgs(trait_name) = parse_macro_input!(args as AttrArgs);
418
419    // parse input, only accept `struct`
420    let input_struct = parse_macro_input!(input as syn::ItemStruct);
421    let visible = &input_struct.vis;
422    let struct_name = &input_struct.ident;
423    let generics = &input_struct.generics;
424
425    let mut struct_searcher = PathFinder::new(struct_name.to_string(), true);
426    let trait_name_str = trait_name.to_string();
427    let mut trait_searcher = PathFinder::new(trait_name_str.clone(), false);
428    let trait_def_path = trait_searcher.get_def_path();
429    assert!(
430        !trait_def_path.is_empty(),
431        "The path for trait `{trait_name}` should NOT be empty!"
432    );
433    let import_statement_tokenstream = if trait_def_path == struct_searcher.get_def_path() {
434        quote! {}
435    } else {
436        let import_statement = trait_searcher.get_hidden_import_statement();
437        syn::parse_str::<TokenStream>(&import_statement)
438            .expect("Failed to parse import statement to TokenStream")
439    };
440
441    // handle different visibility of the struct fields
442    // NOTE: the `original_struct_fields` does not include the hidden trait variable fields
443    let original_struct_fields = input_struct.fields.iter().map(|f| {
444        let field_vis = &f.vis;
445        let field_ident = &f.ident;
446        let field_ty = &f.ty;
447        quote! {
448            #field_vis #field_ident: #field_ty,
449        }
450    });
451
452    // expand code
453    let trait_macro_name = Ident::new(&format!("{}_for_struct", trait_name), trait_name.span());
454    let _hidden_parent_trait_name = Ident::new(&format!("_{}", trait_name), trait_name.span());
455    let expanded = quote! {
456        #import_statement_tokenstream
457        #trait_macro_name! {
458            #visible struct #struct_name #generics {
459                #(#original_struct_fields)*
460            }
461        }
462    };
463
464    // return
465    expanded.into()
466}