traitreg_macros/
lib.rs

1use quote::quote;
2use syn::parse::{Parse, ParseStream};
3use syn::Ident;
4
5/// Register an implementation of a trait on a concrete type.
6///
7/// ```rust
8/// trait MyTrait {}
9/// struct MyType;
10///
11/// #[traitreg::register]
12/// impl MyTrait for MyType {}
13/// ```
14///
15/// Supports registration of a constructor, which can be any associated method with the signature
16/// `fn() -> Self`. For Example:
17///
18/// ```rust
19/// trait MyTrait {}
20///
21/// #[derive(Default)]
22/// struct MyType;
23///
24/// #[traitreg::register(default)]
25/// impl MyTrait for MyType {}
26///
27/// struct MyOtherType;
28/// impl MyOtherType {
29///     fn new() -> Self { Self }
30/// }
31///
32/// #[traitreg::register(new)]
33/// impl MyTrait for MyOtherType {}
34/// ```
35#[proc_macro_attribute]
36pub fn register(
37    attr: proc_macro::TokenStream,
38    item: proc_macro::TokenStream,
39) -> proc_macro::TokenStream {
40    // Read custom / default constructor from attribute if it exists
41    let constructor_fn = if attr.is_empty() {
42        None
43    } else {
44        Some(syn::parse_macro_input!(attr as RegisterAttribute))
45    };
46
47    let has_constructor = constructor_fn.is_some();
48    let has_constructor = quote! { #has_constructor };
49
50    let constructor_fn_call_str = if let Some(cfn) = constructor_fn {
51        let ident = cfn.constructor_fn_ident;
52        quote! {
53            Some(Box::new(Self::#ident()))
54        }
55    } else {
56        quote! {
57            None
58        }
59    };
60
61    let item_clone = item.clone();
62
63    let parsed_item = syn::parse_macro_input!(item as RegisterItem);
64    let item_impl = parsed_item.item;
65
66    let (trait_not, trait_path, _) = item_impl
67        .trait_
68        .expect("Can only register an implementation of a trait, 'impl <Trait> for <Type>'.");
69    assert!(
70        trait_not.is_none(),
71        "Cannot register inverted impl trait: 'impl !Trait for Type'."
72    );
73
74    let trait_ident = trait_path
75        .require_ident()
76        .expect("Expected trait in impl block to have an identifier.");
77    let trait_name = format!("{trait_ident}");
78
79    let type_path = get_self_type_path(&item_impl.self_ty);
80    let type_ident = type_path
81        .require_ident()
82        .expect("Expected type in impl block to have an identifier.");
83    let type_name = format!("{type_ident}");
84
85    let register_static_ident =
86        syn::parse_str::<syn::Ident>(format!("{}_{}__Register", type_ident, trait_ident).as_ref())
87            .expect("Unable to create identifier");
88    let register_static_fn_ident = syn::parse_str::<syn::Ident>(
89        format!("{}_{}__RegisterFn", type_ident, trait_ident).as_ref(),
90    )
91    .expect("Unable to create identifier");
92
93    let mut result: proc_macro::TokenStream = quote! {
94        impl traitreg::RegisteredImpl<Box<dyn #trait_path>> for #type_path {
95            const INSTANCIATE: fn() -> Option<Box<dyn #trait_path>> = || { #constructor_fn_call_str };
96            const HAS_CONSTRUCTOR: bool = #has_constructor;
97            const NAME: &'static str = #type_name;
98            const PATH: &'static str = stringify!(#type_path);
99            const FILE: &'static str = core::file!() ;
100            const MODULE_PATH: &'static str = core::module_path!();
101            const TRAIT_NAME: &'static str = #trait_name;
102        }
103
104        #[used]
105        #[cfg_attr(any(target_os = "linux", target_os = "android"), link_section = ".init_array.10000")]
106        #[cfg_attr(target_os = "freebsd", link_section = ".init_array.10000")]
107        #[cfg_attr(target_os = "netbsd", link_section = ".init_array.10000")]
108        #[cfg_attr(target_os = "openbsd", link_section = ".init_array.10000")]
109        #[cfg_attr(target_os = "dragonfly", link_section = ".init_array.10000")]
110        #[cfg_attr(target_os = "illumos", link_section = ".init_array.10000")]
111        #[cfg_attr(target_os = "haiku", link_section = ".init_array.10000")]
112        #[cfg_attr(target_vendor = "apple", link_section = "__DATA,__mod_init_func")]
113        #[cfg_attr(windows, link_section = ".CRT$XCT")]
114        static #register_static_ident: extern fn() = {
115            extern fn #register_static_fn_ident() {
116                traitreg::__register_impl::<Box<dyn #trait_path>, #type_path>();
117            }
118            #register_static_fn_ident
119        };
120    }.into();
121
122    result.extend(item_clone.clone());
123
124    result
125}
126
127/// Create a registry of implementations of a trait
128///
129/// ```rust
130/// trait MyTrait {}
131///
132/// #[traitreg::registry(MyTrait)]
133/// static MYTRAIT_REGISTRY: () = ();
134/// ```
135#[proc_macro_attribute]
136pub fn registry(
137    attr: proc_macro::TokenStream,
138    item: proc_macro::TokenStream,
139) -> proc_macro::TokenStream {
140    let registry_attr = syn::parse_macro_input!(attr as RegistryAttribute);
141    let registry_item = syn::parse_macro_input!(item as RegistryItem);
142
143    let trait_ident = registry_attr.trait_ident;
144    let item = registry_item.item;
145
146    let trait_name = format!("{trait_ident}");
147    let item_ident = item.ident;
148    let storage_ident = syn::parse_str::<syn::Ident>(format!("{}__STORAGE", item_ident).as_ref())
149        .expect("Unable to create identifier");
150    let wrapper_struct_ident =
151        syn::parse_str::<syn::Ident>(format!("{}__TraitReg", item_ident).as_ref())
152            .expect("Unable to create identifier");
153    let build_static_ident =
154        syn::parse_str::<syn::Ident>(format!("{}__Build", item_ident).as_ref())
155            .expect("Unable to create identifier");
156    let build_static_fn_ident =
157        syn::parse_str::<syn::Ident>(format!("{}__BuildFn", item_ident).as_ref())
158            .expect("Unable to create identifier");
159
160    quote! {
161        static mut #storage_ident: Option<traitreg::TraitRegStorage<Box<dyn #trait_ident>>> = None;
162
163        static #item_ident: #wrapper_struct_ident = #wrapper_struct_ident {};
164
165        struct #wrapper_struct_ident;
166
167        impl ::core::ops::Deref for #wrapper_struct_ident {
168            type Target = traitreg::TraitRegStorage<Box<dyn #trait_ident>>;
169            fn deref(&self) -> &'static traitreg::TraitRegStorage<Box<dyn #trait_ident>> {
170                unsafe {
171                    #storage_ident.as_ref().unwrap()
172                }
173            }
174        }
175
176        #[used]
177        #[cfg_attr(any(target_os = "linux", target_os = "android"), link_section = ".init_array.20000")]
178        #[cfg_attr(target_os = "freebsd", link_section = ".init_array.20000")]
179        #[cfg_attr(target_os = "netbsd", link_section = ".init_array.20000")]
180        #[cfg_attr(target_os = "openbsd", link_section = ".init_array.20000")]
181        #[cfg_attr(target_os = "dragonfly", link_section = ".init_array.20000")]
182        #[cfg_attr(target_os = "illumos", link_section = ".init_array.20000")]
183        #[cfg_attr(target_os = "haiku", link_section = ".init_array.20000")]
184        #[cfg_attr(target_vendor = "apple", link_section = "__DATA,__mod_init_func")]
185        #[cfg_attr(windows, link_section = ".CRT$XCU")]
186        static #build_static_ident: extern fn() = {
187            extern fn #build_static_fn_ident() {
188                let storage = traitreg::TraitRegStorage::<Box<dyn #trait_ident>>::__new(#trait_name);
189
190                unsafe {
191                    #storage_ident = Some(storage)
192                }
193            }
194            #build_static_fn_ident
195        };
196    }.into()
197}
198
199#[derive(Debug)]
200struct RegisterAttribute {
201    constructor_fn_ident: Ident,
202}
203
204impl Parse for RegisterAttribute {
205    fn parse(input: ParseStream) -> syn::Result<Self> {
206        Ok(Self {
207            constructor_fn_ident: Ident::parse(input)?,
208        })
209    }
210}
211
212struct RegisterItem {
213    item: syn::ItemImpl,
214}
215
216impl Parse for RegisterItem {
217    fn parse(input: ParseStream) -> syn::Result<Self> {
218        Ok(Self {
219            item: syn::ItemImpl::parse(input)?,
220        })
221    }
222}
223
224#[derive(Debug)]
225struct RegistryAttribute {
226    trait_ident: Ident,
227}
228
229impl Parse for RegistryAttribute {
230    fn parse(input: ParseStream) -> syn::Result<Self> {
231        Ok(Self {
232            trait_ident: Ident::parse(input)?,
233        })
234    }
235}
236
237struct RegistryItem {
238    item: syn::ItemStatic,
239}
240
241impl Parse for RegistryItem {
242    fn parse(input: ParseStream) -> syn::Result<Self> {
243        Ok(Self {
244            item: syn::ItemStatic::parse(input)?,
245        })
246    }
247}
248
249fn get_self_type_path(self_ty: &syn::Type) -> &syn::Path {
250    if let syn::Type::Path(type_path) = self_ty {
251        return &type_path.path;
252    }
253
254    let error_type = match self_ty {
255        syn::Type::Array(_) => "n array",
256        syn::Type::BareFn(_) => " function",
257        syn::Type::Group(_) => " group",
258        syn::Type::ImplTrait(_) => " trait impl",
259        syn::Type::Infer(_) => "n inferred type (_)",
260        syn::Type::Macro(_) => " macro",
261        syn::Type::Never(_) => " never type",
262        syn::Type::Paren(_) => " parenthesis",
263        syn::Type::Ptr(_) => " pointer",
264        syn::Type::Reference(_) => " reference",
265        syn::Type::Slice(_) => " slice",
266        syn::Type::TraitObject(_) => " trait object",
267        syn::Type::Tuple(_) => " tuple",
268        syn::Type::Verbatim(_) => "n unknown syntax",
269        _ => unreachable!(),
270    };
271
272    panic!(
273        "Cannot register implementation on a{}, expected a struct, enum, union or type alias.",
274        error_type
275    );
276}