1use quote::quote;
2use syn::parse::{Parse, ParseStream};
3use syn::Ident;
4
5#[proc_macro_attribute]
36pub fn register(
37 attr: proc_macro::TokenStream,
38 item: proc_macro::TokenStream,
39) -> proc_macro::TokenStream {
40 let constructor_fn = if attr.is_empty() {
42 None
43 } else {
44 Some(syn::parse_macro_input!(attr as RegisterAttribute))
45 };
46
47 let has_constructor = constructor_fn.is_some();
48 let has_constructor = quote! { #has_constructor };
49
50 let constructor_fn_call_str = if let Some(cfn) = constructor_fn {
51 let ident = cfn.constructor_fn_ident;
52 quote! {
53 Some(Box::new(Self::#ident()))
54 }
55 } else {
56 quote! {
57 None
58 }
59 };
60
61 let item_clone = item.clone();
62
63 let parsed_item = syn::parse_macro_input!(item as RegisterItem);
64 let item_impl = parsed_item.item;
65
66 let (trait_not, trait_path, _) = item_impl
67 .trait_
68 .expect("Can only register an implementation of a trait, 'impl <Trait> for <Type>'.");
69 assert!(
70 trait_not.is_none(),
71 "Cannot register inverted impl trait: 'impl !Trait for Type'."
72 );
73
74 let trait_ident = trait_path
75 .require_ident()
76 .expect("Expected trait in impl block to have an identifier.");
77 let trait_name = format!("{trait_ident}");
78
79 let type_path = get_self_type_path(&item_impl.self_ty);
80 let type_ident = type_path
81 .require_ident()
82 .expect("Expected type in impl block to have an identifier.");
83 let type_name = format!("{type_ident}");
84
85 let register_static_ident =
86 syn::parse_str::<syn::Ident>(format!("{}_{}__Register", type_ident, trait_ident).as_ref())
87 .expect("Unable to create identifier");
88 let register_static_fn_ident = syn::parse_str::<syn::Ident>(
89 format!("{}_{}__RegisterFn", type_ident, trait_ident).as_ref(),
90 )
91 .expect("Unable to create identifier");
92
93 let mut result: proc_macro::TokenStream = quote! {
94 impl traitreg::RegisteredImpl<Box<dyn #trait_path>> for #type_path {
95 const INSTANCIATE: fn() -> Option<Box<dyn #trait_path>> = || { #constructor_fn_call_str };
96 const HAS_CONSTRUCTOR: bool = #has_constructor;
97 const NAME: &'static str = #type_name;
98 const PATH: &'static str = stringify!(#type_path);
99 const FILE: &'static str = core::file!() ;
100 const MODULE_PATH: &'static str = core::module_path!();
101 const TRAIT_NAME: &'static str = #trait_name;
102 }
103
104 #[used]
105 #[cfg_attr(any(target_os = "linux", target_os = "android"), link_section = ".init_array.10000")]
106 #[cfg_attr(target_os = "freebsd", link_section = ".init_array.10000")]
107 #[cfg_attr(target_os = "netbsd", link_section = ".init_array.10000")]
108 #[cfg_attr(target_os = "openbsd", link_section = ".init_array.10000")]
109 #[cfg_attr(target_os = "dragonfly", link_section = ".init_array.10000")]
110 #[cfg_attr(target_os = "illumos", link_section = ".init_array.10000")]
111 #[cfg_attr(target_os = "haiku", link_section = ".init_array.10000")]
112 #[cfg_attr(target_vendor = "apple", link_section = "__DATA,__mod_init_func")]
113 #[cfg_attr(windows, link_section = ".CRT$XCT")]
114 static #register_static_ident: extern fn() = {
115 extern fn #register_static_fn_ident() {
116 traitreg::__register_impl::<Box<dyn #trait_path>, #type_path>();
117 }
118 #register_static_fn_ident
119 };
120 }.into();
121
122 result.extend(item_clone.clone());
123
124 result
125}
126
127#[proc_macro_attribute]
136pub fn registry(
137 attr: proc_macro::TokenStream,
138 item: proc_macro::TokenStream,
139) -> proc_macro::TokenStream {
140 let registry_attr = syn::parse_macro_input!(attr as RegistryAttribute);
141 let registry_item = syn::parse_macro_input!(item as RegistryItem);
142
143 let trait_ident = registry_attr.trait_ident;
144 let item = registry_item.item;
145
146 let trait_name = format!("{trait_ident}");
147 let item_ident = item.ident;
148 let storage_ident = syn::parse_str::<syn::Ident>(format!("{}__STORAGE", item_ident).as_ref())
149 .expect("Unable to create identifier");
150 let wrapper_struct_ident =
151 syn::parse_str::<syn::Ident>(format!("{}__TraitReg", item_ident).as_ref())
152 .expect("Unable to create identifier");
153 let build_static_ident =
154 syn::parse_str::<syn::Ident>(format!("{}__Build", item_ident).as_ref())
155 .expect("Unable to create identifier");
156 let build_static_fn_ident =
157 syn::parse_str::<syn::Ident>(format!("{}__BuildFn", item_ident).as_ref())
158 .expect("Unable to create identifier");
159
160 quote! {
161 static mut #storage_ident: Option<traitreg::TraitRegStorage<Box<dyn #trait_ident>>> = None;
162
163 static #item_ident: #wrapper_struct_ident = #wrapper_struct_ident {};
164
165 struct #wrapper_struct_ident;
166
167 impl ::core::ops::Deref for #wrapper_struct_ident {
168 type Target = traitreg::TraitRegStorage<Box<dyn #trait_ident>>;
169 fn deref(&self) -> &'static traitreg::TraitRegStorage<Box<dyn #trait_ident>> {
170 unsafe {
171 #storage_ident.as_ref().unwrap()
172 }
173 }
174 }
175
176 #[used]
177 #[cfg_attr(any(target_os = "linux", target_os = "android"), link_section = ".init_array.20000")]
178 #[cfg_attr(target_os = "freebsd", link_section = ".init_array.20000")]
179 #[cfg_attr(target_os = "netbsd", link_section = ".init_array.20000")]
180 #[cfg_attr(target_os = "openbsd", link_section = ".init_array.20000")]
181 #[cfg_attr(target_os = "dragonfly", link_section = ".init_array.20000")]
182 #[cfg_attr(target_os = "illumos", link_section = ".init_array.20000")]
183 #[cfg_attr(target_os = "haiku", link_section = ".init_array.20000")]
184 #[cfg_attr(target_vendor = "apple", link_section = "__DATA,__mod_init_func")]
185 #[cfg_attr(windows, link_section = ".CRT$XCU")]
186 static #build_static_ident: extern fn() = {
187 extern fn #build_static_fn_ident() {
188 let storage = traitreg::TraitRegStorage::<Box<dyn #trait_ident>>::__new(#trait_name);
189
190 unsafe {
191 #storage_ident = Some(storage)
192 }
193 }
194 #build_static_fn_ident
195 };
196 }.into()
197}
198
199#[derive(Debug)]
200struct RegisterAttribute {
201 constructor_fn_ident: Ident,
202}
203
204impl Parse for RegisterAttribute {
205 fn parse(input: ParseStream) -> syn::Result<Self> {
206 Ok(Self {
207 constructor_fn_ident: Ident::parse(input)?,
208 })
209 }
210}
211
212struct RegisterItem {
213 item: syn::ItemImpl,
214}
215
216impl Parse for RegisterItem {
217 fn parse(input: ParseStream) -> syn::Result<Self> {
218 Ok(Self {
219 item: syn::ItemImpl::parse(input)?,
220 })
221 }
222}
223
224#[derive(Debug)]
225struct RegistryAttribute {
226 trait_ident: Ident,
227}
228
229impl Parse for RegistryAttribute {
230 fn parse(input: ParseStream) -> syn::Result<Self> {
231 Ok(Self {
232 trait_ident: Ident::parse(input)?,
233 })
234 }
235}
236
237struct RegistryItem {
238 item: syn::ItemStatic,
239}
240
241impl Parse for RegistryItem {
242 fn parse(input: ParseStream) -> syn::Result<Self> {
243 Ok(Self {
244 item: syn::ItemStatic::parse(input)?,
245 })
246 }
247}
248
249fn get_self_type_path(self_ty: &syn::Type) -> &syn::Path {
250 if let syn::Type::Path(type_path) = self_ty {
251 return &type_path.path;
252 }
253
254 let error_type = match self_ty {
255 syn::Type::Array(_) => "n array",
256 syn::Type::BareFn(_) => " function",
257 syn::Type::Group(_) => " group",
258 syn::Type::ImplTrait(_) => " trait impl",
259 syn::Type::Infer(_) => "n inferred type (_)",
260 syn::Type::Macro(_) => " macro",
261 syn::Type::Never(_) => " never type",
262 syn::Type::Paren(_) => " parenthesis",
263 syn::Type::Ptr(_) => " pointer",
264 syn::Type::Reference(_) => " reference",
265 syn::Type::Slice(_) => " slice",
266 syn::Type::TraitObject(_) => " trait object",
267 syn::Type::Tuple(_) => " tuple",
268 syn::Type::Verbatim(_) => "n unknown syntax",
269 _ => unreachable!(),
270 };
271
272 panic!(
273 "Cannot register implementation on a{}, expected a struct, enum, union or type alias.",
274 error_type
275 );
276}