1mod path_utils;
2mod trait_item;
3mod trait_utils;
4
5use path_utils::PathFinder;
6use proc_macro2::TokenStream;
7
8#[allow(unused)]
9use quote::{quote, ToTokens};
10
11use syn::visit::{self, Visit};
12use syn::{
13 braced, token, AngleBracketedGenericArguments, GenericArgument, Generics, PathArguments, Type,
14 TypeParamBound, TypePath, Visibility, WhereClause,
15};
16use syn::{
17 parse::{Parse, ParseStream},
18 parse_macro_input,
19 punctuated::Punctuated,
20 Ident, Token, TraitItem,
21};
22use trait_item::refine_trait_items;
23
/// Collects the names of generic type parameters referenced inside a type.
///
/// A name is treated as a generic parameter when it is exactly one uppercase
/// character (e.g. `T`, `K`, `V`). Names are stored in first-seen order, without
/// duplicates.
struct GenericTypeVisitor {
    /// Generic parameter names discovered so far, in first-seen order.
    generics: Vec<String>,
}

impl GenericTypeVisitor {
    /// Returns `true` when `ident_str` consists of exactly one uppercase character.
    ///
    /// Characters are counted rather than bytes, so a single non-ASCII uppercase
    /// identifier (Rust allows Unicode identifiers, e.g. `Ä`) is recognized too.
    fn is_single_upper_letter(&self, ident_str: &str) -> bool {
        let mut chars = ident_str.chars();
        matches!((chars.next(), chars.next()), (Some(c), None) if c.is_uppercase())
    }
}
32impl<'ast> Visit<'ast> for GenericTypeVisitor {
33 fn visit_type(&mut self, i: &'ast Type) {
34 if let Type::Path(TypePath { path, .. }) = i {
35 if let Some(PathArguments::AngleBracketed(AngleBracketedGenericArguments {
36 args,
37 ..
38 })) = path.segments.last().map(|seg| &seg.arguments)
39 {
40 for arg in args {
41 if let GenericArgument::Type(Type::Path(tp)) = arg {
42 if let Some(ident) = tp.path.get_ident() {
43 let ident_str = ident.to_string();
44 if self.is_single_upper_letter(&ident_str)
45 && !self.generics.contains(&ident_str)
46 {
47 self.generics.push(ident_str);
48 }
49 }
50 }
51 }
52 } else if let Some(seg) = path.segments.last() {
53 let ident_str = seg.ident.to_string();
54 if self.is_single_upper_letter(&ident_str) && !self.generics.contains(&ident_str) {
55 self.generics.push(ident_str);
56 }
57 }
58 }
59 visit::visit_type(self, i);
61 }
62}
#[test]
fn test_generic_type_visitor() {
    // (type source, expected generic names in first-seen order)
    let cases = vec![
        (quote! { V }, vec!["V"]),
        (quote! { Vec<T, HashMap<K, V>> }, vec!["T", "K", "V"]),
    ];
    for (code, expected) in cases {
        let ty: syn::Type = syn::parse2(code).unwrap();
        let mut visitor = GenericTypeVisitor {
            generics: Vec::new(),
        };
        visitor.visit_type(&ty);
        assert_eq!(visitor.generics, expected);
    }
}
84
85struct TraitVarField {
87 var_vis: Visibility,
88 var_name: Ident,
89 type_name: Type,
90 type_generics: Vec<String>,
91}
92impl Parse for TraitVarField {
93 fn parse(input: ParseStream) -> syn::Result<Self> {
94 let var_vis: Visibility = input.parse().expect("Failed to Parse to `var_vis`");
95 let var_name: Ident = input.parse().expect("Failed to Parse to `var_name`");
96 let _: Token![:] = input.parse().expect("Failed to Parse to `:`");
97 let type_name: Type = input.parse().expect("Failed to Parse to `type_name`");
98 let type_generics = {
99 let mut visitor = GenericTypeVisitor {
100 generics: Vec::new(),
101 };
102 visitor.visit_type(&type_name);
103 visitor.generics
104 };
105 Ok(TraitVarField {
106 var_vis,
107 var_name,
108 type_name,
109 type_generics,
110 })
111 }
112}
#[test]
fn test_trait_var_field() {
    let parsed: TraitVarField = syn::parse2(quote! { pub var_name: Vec<T, HashMap<K, V>> })
        .expect("Failed to parse to `TraitVarField`");

    assert!(
        matches!(parsed.var_vis, Visibility::Public(_)),
        "Visibility is not public"
    );
    assert_eq!(parsed.var_name.to_string(), "var_name");

    let rendered_type = parsed.type_name.to_token_stream().to_string();
    assert_eq!(rendered_type, "Vec < T , HashMap < K , V > >");

    // Generics are reported outer-to-inner, left-to-right.
    assert_eq!(parsed.type_generics, ["T", "K", "V"]);
}
133
134struct TraitInput {
135 trait_vis: Visibility,
136 _trait_token: Token![trait],
137 trait_name: Ident,
138 trait_bounds: Option<Generics>, explicit_parent_traits: Option<Punctuated<TypeParamBound, Token![+]>>, where_clause: Option<WhereClause>, _brace_token: token::Brace,
142 trait_variables: Vec<TraitVarField>,
143 trait_items: Vec<TraitItem>,
144}
145
146impl Parse for TraitInput {
147 fn parse(input: ParseStream) -> syn::Result<Self> {
148 let content;
149
150 Ok(TraitInput {
151 trait_vis: input.parse()?,
152 _trait_token: input.parse()?,
153 trait_name: input.parse()?,
154 trait_bounds: if input.peek(Token![<]) {
155 Some(input.parse()?) } else {
157 None
158 },
159 explicit_parent_traits: if input.peek(Token![:]) {
160 input.parse::<Token![:]>()?;
161 let mut parent_traits = Punctuated::new();
162 while !input.peek(Token![where]) && !input.peek(token::Brace) {
163 parent_traits.push_value(input.parse()?);
164 if input.peek(Token![+]) {
165 parent_traits.push_punct(input.parse()?);
166 } else {
167 break;
168 }
169 }
170 Some(parent_traits)
171 } else {
172 None
173 },
174 where_clause: if input.peek(syn::token::Where) {
175 Some(input.parse()?)
176 } else {
177 None
178 },
179 _brace_token: braced!(content in input),
180 trait_variables: {
182 let mut v = Vec::new();
183 while !content.peek(Token![type])
184 && !content.peek(Token![const])
185 && !content.peek(Token![fn])
186 && !content.is_empty()
187 {
188 v.push(content.call(TraitVarField::parse)?);
189 let _: Token![;] = content.parse()?;
190 }
191 v
192 },
193 trait_items: {
194 let mut items = Vec::new();
195 while !content.is_empty() {
196 items.push(content.parse()?);
197 }
198 items
199 },
200 })
201 }
202}
203
#[test]
fn test_trait_input() {
    let parsed: TraitInput = syn::parse2(quote! {
        pub trait MyTrait {
            x: Vec<T, HashMap<K, V>>;
            pub y: bool;

            fn print_x(&self){
                println!("x: `{}`", self.x);
            }
            fn print_y(&self){
                println!("y: `{}`", self.y);
            }
            fn print_all(&self);
        }
    })
    .unwrap();

    // Trait header.
    assert!(matches!(parsed.trait_vis, Visibility::Public(_)));
    assert_eq!(parsed.trait_name.to_string(), "MyTrait");
    assert!(parsed.trait_bounds.is_none());
    assert!(parsed.explicit_parent_traits.is_none());
    assert!(parsed.where_clause.is_none());

    // Variables are split off from the regular items, in declaration order.
    let var_names: Vec<String> = parsed
        .trait_variables
        .iter()
        .map(|var| var.var_name.to_string())
        .collect();
    assert_eq!(var_names, ["x", "y"]);

    let items: Vec<String> = parsed
        .trait_items
        .iter()
        .map(|item| item.to_token_stream().to_string())
        .collect();
    assert_eq!(
        items,
        [
            "fn print_x (& self) { println ! (\"x: `{}`\" , self . x) ; }",
            "fn print_y (& self) { println ! (\"y: `{}`\" , self . y) ; }",
            "fn print_all (& self) ;",
        ]
    );
}
250
251#[proc_macro]
253pub fn trait_variable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
254 let TraitInput {
255 trait_vis,
256 trait_name,
257 trait_bounds,
258 explicit_parent_traits,
259 where_clause,
260 trait_variables,
261 trait_items,
262 ..
263 } = parse_macro_input!(input as TraitInput);
264
265 let hidden_parent_trait_name = Ident::new(&format!("_{}", trait_name), trait_name.span());
267 let trait_decl_macro_name =
269 Ident::new(&format!("{}_for_struct", trait_name), trait_name.span());
270
271 let hidden_parent_trait_methods_signatures = trait_variables.iter().map(
273 |TraitVarField {
274 var_name,
275 type_name,
276 ..
277 }| {
278 let method_name = Ident::new(&format!("_{}", var_name), var_name.span());
279 let method_name_mut = Ident::new(&format!("_{}_mut", var_name), var_name.span());
280 quote! {
281 fn #method_name(&self) -> &#type_name;
282 fn #method_name_mut(&mut self) -> &mut #type_name;
283 }
284 },
285 );
286 let trait_fields_in_struct = trait_variables.iter().map(
288 |TraitVarField {
289 var_vis,
290 var_name,
291 type_name,
292 ..
293 }| {
294 quote! {
295 #var_vis #var_name: #type_name,
296 }
297 },
298 );
299 let parent_trait_methods_impls_in_struct = trait_variables.iter().map(
301 |TraitVarField {
302 var_name,
303 type_name,
304 ..
305 }| {
306 let method_name = Ident::new(&format!("_{}", var_name), var_name.span());
307 let method_name_mut = Ident::new(&format!("_{}_mut", var_name), var_name.span());
308 quote! {
309 fn #method_name(&self) -> &#type_name{
310 &self.#var_name
311 }
312 fn #method_name_mut(&mut self) -> &mut #type_name{
313 &mut self.#var_name
314 }
315 }
316 },
317 );
318 let hidden_parent_trait_bounds = {
320 let mut generic_types = Vec::new();
321 for trait_var in trait_variables.iter() {
322 for generic in &trait_var.type_generics {
323 let generic_ident = syn::Ident::new(generic, proc_macro2::Span::call_site());
324 if !generic_types.contains(&generic_ident) {
325 generic_types.push(generic_ident);
326 }
327 }
328 }
329 if !generic_types.is_empty() {
330 quote! { <#(#generic_types),*> }
331 } else {
332 TokenStream::new()
333 }
334 };
335
336 let trait_items = refine_trait_items(trait_items);
338
339 let hidden_parent_trait_with_bounds =
341 quote! {#hidden_parent_trait_name #hidden_parent_trait_bounds};
342 let expanded_trait_code = quote! {
343 #trait_vis trait #hidden_parent_trait_with_bounds {
344 #(#hidden_parent_trait_methods_signatures)*
345 }
346 #trait_vis trait #trait_name #trait_bounds: #hidden_parent_trait_with_bounds + #explicit_parent_traits #where_clause {
347 #(#trait_items)*
348 }
349 };
350
351 let declarative_macro_code = quote! {
353 #[doc(hidden)]
354 #[macro_export] macro_rules! #trait_decl_macro_name { (
357 $(#[$struct_attr:meta])* $vis:vis struct $struct_name:ident
359 $(<$($generic_param:ident),* $(, $generic_lifetime:lifetime)* $(,)? >)?
360 {
362 $($struct_content:tt)*
363 }
364 ) => {
365 $(#[$struct_attr])*
367 $vis struct $struct_name
368 $(<$($generic_param),* $(, $generic_lifetime)*>)?
369 {
371 $($struct_content)*
372 #(
373 #trait_fields_in_struct
374 )*
375 }
376 impl
378 $(<$($generic_param),* $(, $generic_lifetime)*>)?
380 #hidden_parent_trait_with_bounds
382 for
383 $struct_name
385 $(<$($generic_param),* $(, $generic_lifetime)*>)?
386 {
387 #(
388 #parent_trait_methods_impls_in_struct
389 )*
390 }
391 };
392 }
393 };
394
395 proc_macro::TokenStream::from(quote! {
397 #expanded_trait_code
398 #declarative_macro_code
399 })
400}
401
402struct AttrArgs(Ident);
404impl syn::parse::Parse for AttrArgs {
405 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
406 let ident = input.parse()?;
407 Ok(AttrArgs(ident))
408 }
409}
410#[proc_macro_attribute]
412pub fn trait_var(
413 args: proc_macro::TokenStream,
414 input: proc_macro::TokenStream,
415) -> proc_macro::TokenStream {
416 let AttrArgs(trait_name) = parse_macro_input!(args as AttrArgs);
418
419 let input_struct = parse_macro_input!(input as syn::ItemStruct);
421 let visible = &input_struct.vis;
422 let struct_name = &input_struct.ident;
423 let generics = &input_struct.generics;
424
425 let mut struct_searcher = PathFinder::new(struct_name.to_string(), true);
426 let trait_name_str = trait_name.to_string();
427 let mut trait_searcher = PathFinder::new(trait_name_str.clone(), false);
428 let trait_def_path = trait_searcher.get_def_path();
429 assert!(
430 !trait_def_path.is_empty(),
431 "The path for trait `{trait_name}` should NOT be empty!"
432 );
433 let import_statement_tokenstream = if trait_def_path == struct_searcher.get_def_path() {
434 quote! {}
435 } else {
436 let import_statement = trait_searcher.get_hidden_import_statement();
437 syn::parse_str::<TokenStream>(&import_statement)
438 .expect("Failed to parse import statement to TokenStream")
439 };
440
441 let original_struct_fields = input_struct.fields.iter().map(|f| {
444 let field_vis = &f.vis;
445 let field_ident = &f.ident;
446 let field_ty = &f.ty;
447 quote! {
448 #field_vis #field_ident: #field_ty,
449 }
450 });
451
452 let trait_macro_name = Ident::new(&format!("{}_for_struct", trait_name), trait_name.span());
454 let _hidden_parent_trait_name = Ident::new(&format!("_{}", trait_name), trait_name.span());
455 let expanded = quote! {
456 #import_statement_tokenstream
457 #trait_macro_name! {
458 #visible struct #struct_name #generics {
459 #(#original_struct_fields)*
460 }
461 }
462 };
463
464 expanded.into()
466}