1#![cfg_attr(nightly, feature(proc_macro_diagnostic))]
10
11use std::collections::HashSet;
12
13use ctxt::Ctxt;
14use quote::{quote, quote_spanned, ToTokens};
15use syn::spanned::Spanned;
16use syn::{parse_macro_input, parse_quote, Attribute, DeriveInput, GenericParam, Generics, Meta, NestedMeta, Path};
17
18mod ctxt;
19
20#[proc_macro_derive(TraitMapEntry, attributes(trait_map))]
24pub fn derive_trait_map_entry(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
25 let derive_input = parse_macro_input!(input as DeriveInput);
26
27 let name = derive_input.ident;
28
29 let generics = add_trait_bounds(derive_input.generics);
31 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
32
33 let mut ctx = Ctxt::new();
36 let traits_to_include: Vec<_> = parse_attributes(&mut ctx, derive_input.attrs)
37 .into_iter()
38 .filter_map(|attr| parse_attribute_trait(&mut ctx, attr))
39 .collect();
40
41 let mut found_traits = HashSet::new();
50 let mut duplicate_traits = Vec::new();
51 let functions: Vec<_> = traits_to_include
52 .into_iter()
53 .filter_map(|t| {
54 if found_traits.contains(&t) {
55 duplicate_traits.push(t);
56 None
57 } else {
58 let span = t.span();
59 let function = quote_spanned!(span => .add_trait::<dyn #t>());
60 found_traits.insert(t);
61 Some(function)
62 }
63 })
64 .collect();
65
66 if let Err(errors) = ctx.check() {
68 let compile_errors = errors.iter().map(syn::Error::to_compile_error);
69 return quote!(#(#compile_errors)*).into();
70 }
71
72 #[cfg(nightly)]
74 {
75 if found_traits.len() == 0 {
76 proc_macro::Span::call_site()
77 .warning("no traits specified for `TraitMapEntry`")
78 .help("specify one or more traits using `#[trait_map(TraitOne, some::path::TraitTwo, ...)]`")
79 .emit();
80 }
81
82 for t in duplicate_traits {
83 let path_str: String = t.to_token_stream().into_iter().map(|t| format!("{}", t)).collect();
84 t.span()
85 .unwrap()
86 .warning(format!("duplicate trait `{}`", path_str))
87 .note("including the same trait multiple times is a no-op")
88 .emit();
89 }
90 }
91
92 quote! {
93 impl #impl_generics trait_map::TraitMapEntry for #name #ty_generics #where_clause {
94 fn on_create<'a>(&mut self, context: trait_map::Context<'a>) {
95 context.downcast::<Self>() #(#functions)* ;
96 }
97 }
98 }
99 .into()
100}
101
102fn add_trait_bounds(mut generics: Generics) -> Generics {
104 for param in &mut generics.params {
105 if let GenericParam::Type(ref mut type_param) = *param {
106 type_param.bounds.push(parse_quote!('static));
107 }
108 }
109 generics
110}
111
112fn parse_attributes(ctx: &mut Ctxt, attributes: Vec<Attribute>) -> Vec<NestedMeta> {
114 attributes
115 .into_iter()
116 .filter(|attr| attr.path.is_ident("trait_map"))
119 .map(|attr| match attr.parse_meta() {
120 Ok(Meta::List(meta)) => meta.nested.into_iter().collect::<Vec<_>>(),
123
124 Ok(other) => {
127 ctx.error_spanned_by(other, "expected #[trait_map(...)]");
128 Vec::new()
129 },
130
131 Err(err) => {
132 ctx.syn_error(err);
133 Vec::new()
134 },
135 })
136 .flatten()
137 .collect()
138}
139
140fn parse_attribute_trait(ctx: &mut Ctxt, attr: NestedMeta) -> Option<Path> {
146 match attr {
147 NestedMeta::Meta(Meta::Path(trait_path)) => Some(trait_path),
151
152 other => {
156 ctx.error_spanned_by(other, "unexpected attribute, please specify a valid trait");
157 None
158 },
159 }
160}