trait_variable_macros/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use regex::{Captures, Regex};
6use syn::{braced, token, Visibility};
7use syn::{
8    parse::{Parse, ParseStream},
9    parse_macro_input,
10    punctuated::Punctuated,
11    Ident, Token, TraitItem, Type,
12};
13
14struct TraitVarField {
15    var_vis: Visibility,
16    var_name: Ident,
17    _colon_token: Token![:],
18    ty: Type,
19}
20impl Parse for TraitVarField {
21    fn parse(input: ParseStream) -> syn::Result<Self> {
22        Ok(TraitVarField {
23            var_vis: input.parse()?,
24            var_name: input.parse()?,
25            _colon_token: input.parse()?,
26            ty: input.parse()?,
27        })
28    }
29}
30
31struct TraitInput {
32    trait_vis: Visibility,
33    _trait_token: Token![trait],
34    trait_name: Ident,
35    _brace_token: token::Brace,
36    trait_variables: Punctuated<TraitVarField, Token![;]>,
37    trait_items: Vec<TraitItem>,
38}
39
40impl Parse for TraitInput {
41    fn parse(input: ParseStream) -> syn::Result<Self> {
42        let content;
43        Ok(TraitInput {
44            trait_vis: input.parse()?,
45            _trait_token: input.parse()?,
46            trait_name: input.parse()?,
47            _brace_token: braced!(content in input),
48            // Parse all variable declarations until a method or end of input is encountered
49            trait_variables: {
50                let mut vars = Punctuated::new();
51                while !content.peek(Token![fn]) && !content.peek(Token![;]) && !content.is_empty() {
52                    vars.push_value(content.parse()?);
53                    // Ensure that a semicolon follows the variable declaration
54                    if !content.peek(Token![;]) {
55                        return Err(content.error("expected `;` after variable declaration"));
56                    }
57                    vars.push_punct(content.parse()?);
58                }
59                vars
60            },
61            // Parse all method declarations
62            trait_items: {
63                let mut items = Vec::new();
64                while !content.is_empty() {
65                    items.push(content.parse()?);
66                }
67                items
68            },
69        })
70    }
71}
72
73/// functional macro: used to generate code for a trait with variable fields
74#[proc_macro]
75pub fn trait_variable(input: TokenStream) -> TokenStream {
76    let TraitInput {
77        trait_vis,
78        trait_name,
79        trait_variables,
80        trait_items,
81        ..
82    } = parse_macro_input!(input as TraitInput);
83    // 1.1 get parent trait name
84    let parent_trait_name = Ident::new(&format!("_{}", trait_name), trait_name.span());
85    // 1.2 get trait declarative macro name
86    let trait_decl_macro_name =
87        Ident::new(&format!("{}_for_struct", trait_name), trait_name.span());
88
89    // 2.1 generate parent trait methods declaration
90    let parent_trait_methods =
91        trait_variables
92            .iter()
93            .map(|TraitVarField { var_name, ty, .. }| {
94                let method_name = Ident::new(&format!("_{}", var_name), var_name.span());
95                let method_name_mut = Ident::new(&format!("_{}_mut", var_name), var_name.span());
96                quote! {
97                    fn #method_name(&self) -> &#ty;
98                    fn #method_name_mut(&mut self) -> &mut #ty;
99                }
100            });
101    // 2.2 generate trait variable fields definition for structs later
102    let struct_trait_fields_defs = trait_variables.iter().map(
103        |TraitVarField {
104             var_vis,
105             var_name,
106             ty,
107             ..
108         }| {
109            quote! {
110                #var_vis #var_name: #ty,
111            }
112        },
113    );
114    // 2.3 generate parent trait methods implementation for struct
115    let parent_trait_methods_impls =
116        trait_variables
117            .iter()
118            .map(|TraitVarField { var_name, ty, .. }| {
119                let method_name = Ident::new(&format!("_{}", var_name), var_name.span());
120                let method_name_mut = Ident::new(&format!("_{}_mut", var_name), var_name.span());
121                quote! {
122                    fn #method_name(&self) -> &#ty{
123                        &self.#var_name
124                    }
125                    fn #method_name_mut(&mut self) -> &mut #ty{
126                        &mut self.#var_name
127                    }
128                }
129            });
130
131    // 3. refine the body of methods from the original trait
132    // let original_trait_items = trait_items;
133    let original_trait_items = trait_items.into_iter().map(|item| {
134        if let TraitItem::Method(mut method) = item {
135            if let Some(body) = &mut method.default {
136                // Use regular expressions or other methods to find and replace text
137                let re = Regex::new(r"sself\.([a-zA-Z_]\w*)").unwrap();
138                let body_str = quote!(#body).to_string();
139                let new_body_str = re
140                    .replace_all(&body_str, |caps: &Captures| {
141                        let name = &caps[1];
142                        // Check if it is followed by braces
143                        if body_str.contains(&format!("{}(", name)) {
144                            format!("self.{}", name)
145                        } else {
146                            format!("self._{}()", name)
147                        }
148                    })
149                    .to_string();
150
151                let new_body: TokenStream = new_body_str.parse().expect("Failed to parse new body");
152                method.default = Some(syn::parse(new_body).expect("Failed to parse method body"));
153            }
154            quote! { #method }
155        } else {
156            quote! { #item }
157        }
158    });
159
160    // 4. generate the hidden declarative macro for target struct
161    let decl_macro_code = quote! {
162        #[doc(hidden)]
163        #[macro_export] // it is ok to always export the declarative macro
164        macro_rules! #trait_decl_macro_name { // NOTE: the reexpanded macro is used for rust struct only
165            (
166                ($hidden_parent_trait:path)
167                $(#[$struct_attr:meta])* // NOTE: make sure the style is consistent with that in arm 2 output
168                $vis:vis struct $struct_name:ident {
169                    $($struct_content:tt)*
170                }
171            ) => {
172                $(#[$struct_attr])*
173                $vis struct $struct_name {
174                    $($struct_content)*
175                    #(
176                        #struct_trait_fields_defs
177                    )*
178                }
179                impl $hidden_parent_trait for $struct_name {
180                    #(
181                        #parent_trait_methods_impls
182                    )*
183                }
184            };
185        }
186    };
187    // 5. expand the final code
188    let expanded = quote! {
189        #trait_vis trait #parent_trait_name {
190            #(#parent_trait_methods)*
191        }
192        #trait_vis trait #trait_name: #parent_trait_name {
193            #(#original_trait_items)*
194        }
195
196        #decl_macro_code
197    };
198    TokenStream::from(expanded)
199}
200
201/// attribute macro: used to tag Rust struct like: `#[trait_var(<trait_name>)]`
202#[proc_macro_attribute]
203pub fn trait_var(args: TokenStream, input: TokenStream) -> TokenStream {
204    // parse attributes
205    let args = parse_macro_input!(args as syn::AttributeArgs);
206    let trait_name = match args.first().unwrap() {
207        syn::NestedMeta::Meta(syn::Meta::Path(path)) => path.get_ident().unwrap(),
208        _ => panic!("Expected a trait name"),
209    };
210
211    // parse input, only accept `struct`
212    let input_struct = parse_macro_input!(input as syn::ItemStruct);
213    let visible = &input_struct.vis;
214    let struct_name = &input_struct.ident;
215
216    // handle different visibility of the struct fields
217    let struct_fields = input_struct.fields.iter().map(|f| {
218        let field_vis = &f.vis;
219        let field_ident = &f.ident;
220        let field_ty = &f.ty;
221        quote! {
222            #field_vis #field_ident: #field_ty,
223        }
224    });
225
226    // expand code
227    let trait_macro_name = Ident::new(&format!("{}_for_struct", trait_name), trait_name.span());
228    let parent_trait_name = Ident::new(&format!("_{}", trait_name), trait_name.span());
229    let expanded = quote! {
230        #trait_macro_name! {
231            (#parent_trait_name)
232            // (#hidden_trait_path) // TODO: delete?
233            #visible struct #struct_name {
234                #(#struct_fields)*
235            }
236        }
237    };
238
239    // return
240    expanded.into()
241}