trait_tests/
lib.rs

1#![feature(proc_macro)] //proc_macro_lib
2#![crate_type = "proc-macro"]
3
4extern crate proc_macro;
5extern crate proc_macro2;
6extern crate syn;
7#[macro_use]
8extern crate quote;
9
10//
11// Example https://github.com/actix/actix-derive/blob/master/src/lib.rs
12//
13
14use proc_macro::TokenStream;
15use proc_macro2::Span;
16use quote::TokenStreamExt;
17use syn::token::Comma;
18use syn::{
19    AngleBracketedGenericArguments, Binding, FnDecl, GenericArgument, Ident, Item, ItemImpl,
20    ItemTrait, MethodSig, Path, PathArguments, PathSegment, ReturnType, TraitBound, TraitItem,
21    TraitItemMethod, Type, TypeParamBound, TypePath,
22};
23
24#[cfg_attr(feature = "cargo-clippy", allow(needless_pass_by_value))]
25#[proc_macro_attribute]
26pub fn trait_tests(_attr: TokenStream, input: TokenStream) -> TokenStream {
27    // Construct a string representation of the type definition
28
29    //TODO: Error if test trait is not pub.
30    let output;
31    if let Ok(trait_def) = syn::parse(input.clone()) {
32        let mut trait_def: syn::ItemTrait = trait_def;
33        trait_def = inject_test_all_method(trait_def);
34        output = quote!(#trait_def); //TODO looses span information!
35
36        let mut tokens = proc_macro2::TokenStream::new();
37
38        let trait_name_str = trait_def.ident.clone();
39
40        let p: TypeParamBound = trait_def
41            .supertraits
42            .iter()
43            .nth(0)
44            .expect("trait should have a supertrait that you are testing.")
45            .clone();
46
47        if let TypeParamBound::Trait(TraitBound { path, .. }) = p {
48            let first_segment = path.segments.iter().nth(0).unwrap();
49            if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
50                ref args, ..
51            }) = first_segment.arguments
52            {
53                for (i, generic_arg) in args.iter().enumerate() {
54                    match generic_arg {
55                        GenericArgument::Type(gtype) => {
56                            let typename = Ident::new(
57                                &format!("{}Type{}", trait_name_str, i + 1),
58                                Span::call_site(),
59                            );
60                            tokens.append_all(
61                                quote!(#[allow(dead_code)] pub type #typename = #gtype;),
62                            );
63                        }
64                        GenericArgument::Binding(Binding {
65                            ty: gtype,
66                            ident: _ident,
67                            ..
68                        }) => {
69                            let typename = Ident::new(
70                                &format!("{}Type{}", trait_name_str, i + 1),
71                                Span::call_site(),
72                            );
73                            tokens.append_all(
74                                quote!(#[allow(dead_code)] pub type #typename = #gtype;),
75                            );
76                        }
77                        _ => { /* ignore */ }
78                    }
79                }
80            }
81        }
82
83        //Add in type definitions...
84        tokens.append_all(output);
85        //println!("trait_def: {:#?}", &tokens);
86        return tokens.into();
87    } else {
88        panic!("Expected this attribute to be on a trait.");
89    }
90}
91
92#[cfg_attr(feature = "cargo-clippy", allow(needless_pass_by_value))]
93#[proc_macro_attribute]
94pub fn test_impl(_attr: TokenStream, input: TokenStream) -> TokenStream {
95    let mut results = proc_macro2::TokenStream::new();
96    let ast: syn::Item =
97        syn::parse(input.clone()).expect("Unexpected - needs to be on impl X for Y");
98
99    results.append_all(quote!(#ast)); //TODO looses span information!
100
101    if let Item::Impl(ItemImpl {
102        trait_: Some((_opt, trait_ident, _for)),
103        self_ty,
104        ..
105    }) = ast
106    {
107        if let Type::Path(ref struct_type) = *self_ty {
108            let TypePath { path, .. } = struct_type.clone();
109            let Path { segments, .. } = path;
110            let seg1: PathSegment = segments[0].clone();
111            let PathSegment { arguments, .. } = seg1;
112            if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) =
113                arguments
114            {
115                let mut arg_uments = vec![];
116                for _arg in args.iter() {
117                    arg_uments.push(quote!(a));
118                }
119
120                results.append_all(process_case(struct_type, &trait_ident, &arg_uments));
121            } else {
122                results.append_all(process_case(struct_type, &trait_ident, &[]));
123            }
124        } else {
125            panic!("needs to be on an impl...");
126        }
127    } else {
128        panic!("needs to be on impl.");
129    }
130    //println!("trait_impl: {:#?}", &results);
131    results.into()
132}
133
134fn process_case(
135    struct_ident: &TypePath,
136    trait_path: &Path,
137    impltypes_y: &[proc_macro2::TokenStream],
138) -> proc_macro2::TokenStream {
139    let test_fn_name = generate_unique_test_name(struct_ident, trait_path, &impltypes_y);
140
141    let mut impltypes_punctuated = proc_macro2::TokenStream::new();
142    let (trait_name, num_params_trait_takes) =
143        get_type_with_filled_in_type_params_trait(trait_path);
144    let trait_name_str = quote!(#trait_name).to_string();
145
146    let mut v = vec![];
147    for (i, _) in impltypes_y.iter().enumerate() {
148        v.push(Ident::new(
149            &format!("{}Type{}", trait_name_str, i),
150            Span::call_site(),
151        ))
152    }
153
154    impltypes_punctuated.append_separated(v, quote!(,));
155
156    let TypePath { path, .. } = struct_ident;
157    let impl_type_name =
158        get_type_with_filled_in_type_params_impl(path, &trait_name_str, num_params_trait_takes);
159
160    quote!( #[test]
161            fn #test_fn_name() {
162                <#impl_type_name as #trait_name>::test_all();
163            }
164
165            impl #trait_name for #impl_type_name {})
166}
167
168fn get_type_with_filled_in_type_params_trait(trait_path: &Path) -> (PathSegment, usize) {
169    let Path { segments, .. } = trait_path;
170    if segments.len() > 1 {
171        panic!("untested");
172    } else {
173        let PathSegment { ident, arguments } = segments[0].clone();
174        let arg_num = match arguments {
175            PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => {
176                args.len()
177            }
178            PathArguments::None => 0,
179            _ => panic!("unimplemented"),
180        };
181        (
182            PathSegment {
183                ident: Ident::new(&(quote!(#ident).to_string() + "Tests"), Span::call_site()),
184                arguments: PathArguments::None,
185            },
186            arg_num,
187        )
188    }
189}
190
191fn get_type_with_filled_in_type_params_impl(
192    impl_path: &Path,
193    trait_name: &str,
194    num_params_trait_takes: usize,
195) -> PathSegment {
196    let Path { segments, .. } = impl_path;
197    if segments.len() > 1 {
198        panic!("untested");
199    } else {
200        let PathSegment { ident, arguments } = segments[0].clone();
201        if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) =
202            arguments
203        {
204            //Impl has arguments.
205            if num_params_trait_takes == args.len() {
206                let mut s = String::from("<");
207                for arg in 0..num_params_trait_takes {
208                    if arg > 0 {
209                        s.push(',');
210                    }
211                    s.push_str(&format!("{}Type{}", trait_name, arg + 1));
212                }
213                s.push('>');
214                let final_args: PathArguments = if num_params_trait_takes == 0 {
215                    PathArguments::None
216                } else {
217                    let ppf: AngleBracketedGenericArguments = syn::parse_str(&s).unwrap();
218                    PathArguments::AngleBracketed(ppf)
219                };
220                PathSegment {
221                    ident: Ident::new(&(quote!(#ident).to_string()), Span::call_site()),
222                    arguments: final_args,
223                }
224            } else if num_params_trait_takes == 0 {
225                //Case trait has no generic params, impl has generic params.
226                //If these are non-concrete types we should substitute them.
227                //For now we consider single letter 'T', 'U' etc. as being non-concrete types.
228                let mut next_arg_num = 1;
229                let mut concrete_args =
230                    syn::punctuated::Punctuated::<GenericArgument, Comma>::new();
231                for arg in args {
232                    let arg_len = quote!(#arg).to_string().len();
233                    if arg_len < 2 {
234                        let new_arg = format!("{}Type{}", trait_name, next_arg_num);
235                        let ga: GenericArgument = syn::parse_str(&new_arg).unwrap();
236                        concrete_args.push_value(ga);
237                        next_arg_num += 1;
238                    } else {
239                        concrete_args.push_value(arg);
240                    }
241                }
242
243                //leave well alone - keep the arguments as these are likely to be concrete types rather than bindings...:
244
245                PathSegment {
246                    ident: Ident::new(&(quote!(#ident).to_string()), Span::call_site()),
247                    arguments: PathArguments::AngleBracketed(AngleBracketedGenericArguments {
248                        colon2_token: None,
249                        lt_token: syn::token::Lt([Span::call_site()]),
250                        args: concrete_args,
251                        gt_token: syn::token::Gt([Span::call_site()]),
252                    }),
253                }
254            } else {
255                panic!("consider case");
256            }
257        } else {
258            //Case no angle bracketed args on impl
259            PathSegment {
260                ident: Ident::new(&(quote!(#ident).to_string()), Span::call_site()),
261                arguments,
262            }
263        }
264    }
265}
266
267fn generate_unique_test_name(
268    struct_ident: &TypePath,
269    trait_name: &Path,
270    params: &[proc_macro2::TokenStream],
271) -> Ident {
272    let mut root =quote!(#struct_ident).to_string();
273    root.push('_');
274    root.push_str(&quote!(#trait_name).to_string());
275    for param in params {
276        root.push('_');
277        root.push_str(&param.clone().to_string());
278    }
279    syn::Ident::new(
280        &root
281            .to_lowercase()
282            .replace("<", "_")
283            .replace(">", "")
284            .replace("\"", "")
285            .replace(" ", "_")
286            .replace(",", "_")
287            .replace("__", "_")
288            .replace("__", "_"),
289        Span::call_site(),
290    )
291}
292
293fn inject_test_all_method(trait_def: ItemTrait) -> ItemTrait {
294    let mut items = trait_def.items.clone();
295    let mut test_calls: Vec<Ident> = Vec::new();
296    for item in &items {
297        if let TraitItem::Method(TraitItemMethod {
298            sig:
299                MethodSig {
300                    ident: ref a,
301                    decl:
302                        FnDecl {
303                            output: ReturnType::Default,
304                            inputs: ref args,
305                            ..
306                        },
307                    ..
308                },
309            ..
310        }) = item
311        {
312            if args.is_empty() {
313                test_calls.push(a.clone());
314            }
315        }
316    }
317
318    let test_all_fn = syn::parse(
319        quote!(
320        fn test_all() {
321            #(Self::#test_calls());*
322        }
323    ).into(),
324    ).unwrap();
325
326    items.push(test_all_fn);
327    syn::ItemTrait { items, ..trait_def }
328}