1extern crate proc_macro;
16use proc_macro2::{Ident, Span, TokenStream};
17
18use quote::{format_ident, quote, quote_spanned, ToTokens};
19use syn::{
20 parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Generics, ItemTrait,
21 Result, Token, TraitItem, TraitItemFn, TypeParamBound,
22};
23
24fn unimplemented(x: &impl Spanned, things: &str) -> Error {
25 Error::new(
26 x.span(),
27 format!("{things} are not implemented for tinydyn"),
28 )
29}
30
31fn generics_unimplemented(generics: &Generics) -> Result<()> {
32 if let Some(where_clause) = &generics.where_clause {
33 return Err(unimplemented(where_clause, "where clauses"));
34 }
35 if !generics.params.is_empty() {
36 return Err(unimplemented(&generics.params, "generics"));
37 }
38 Ok(())
39}
40
41fn supertraits_unimplemented(supertraits: &Punctuated<TypeParamBound, Token![+]>) -> Result<()> {
42 if !supertraits.is_empty() {
43 return Err(unimplemented(&supertraits, "supertraits"));
44 }
45 Ok(())
46}
47
48fn unsafe_trait_unsupported(unsafety: &Option<Token![unsafe]>) -> Result<()> {
49 if let Some(unsafety) = unsafety {
50 return Err(unimplemented(unsafety, "unsafe traits"));
51 }
52 Ok(())
53}
54
55struct CommonNames {
57 tinydyn: Ident,
58 trait_ident: Ident,
59 trait_object: TokenStream,
60 private: TokenStream,
61 self_local: Ident,
62 meta_local: Ident,
63 vtable_ident: Ident,
64 concrete: TokenStream,
65}
66
67impl CommonNames {
68 fn new(trait_ident: Ident) -> Self {
69 let tinydyn = format_ident!("tinydyn");
70 let private = quote!(#tinydyn ::__private);
71 let self_local = Ident::new("self_", Span::mixed_site());
72 let meta_local = Ident::new("meta", Span::mixed_site());
73 let trait_object = quote!(dyn #trait_ident);
74 let vtable_ident = format_ident!("{trait_ident}Vtable");
75 let concrete = "Concrete".parse().unwrap();
76 Self {
77 tinydyn,
78 private,
79 self_local,
80 meta_local,
81 trait_ident,
82 trait_object,
83 vtable_ident,
84 concrete,
85 }
86 }
87}
88
89#[derive(Clone)]
90struct ReceiverArg<'a> {
91 type_: ReceiverType,
92 ident: &'a Ident,
93 elem: &'a syn::TypeReference,
94}
95
96impl<'a> ReceiverArg<'a> {
97 fn new(receiver: &'a syn::Receiver, names: &'a CommonNames) -> Result<Self> {
98 let syn::Type::Reference(elem) = &*receiver.ty else {
99 return Err(unimplemented(receiver, "non-reference methods"));
100 };
101 let type_;
102 let ident;
103 match &*elem.elem {
104 syn::Type::Path(path) if path.path.is_ident("Self") => {
105 ident = &names.self_local;
106 type_ = if elem.mutability.is_some() {
107 ReceiverType::MutableRef
108 } else {
109 ReceiverType::SharedRef
110 };
111 }
112 _ => return Err(unimplemented(receiver, "non-reference methods")),
113 };
114 Ok(Self { type_, elem, ident })
115 }
116}
117
/// How a method receiver borrows `Self`.
#[derive(Copy, Clone)]
enum ReceiverType {
    /// `&mut self`: an exclusive borrow.
    MutableRef,
    /// `&self`: a shared borrow.
    SharedRef,
}
126
127impl ToTokens for ReceiverArg<'_> {
128 fn to_tokens(&self, tokens: &mut TokenStream) {
129 self.elem.to_tokens(tokens)
130 }
131}
132
133impl From<&Option<Token![mut]>> for ReceiverType {
134 fn from(mutability: &Option<Token![mut]>) -> Self {
135 use ReceiverType::*;
136 if mutability.is_some() {
137 MutableRef
138 } else {
139 SharedRef
140 }
141 }
142}
143
144struct MethodArgInfo<'a> {
145 needs_bare_transmute: BareConversionNeeded,
146 orig_arg_type: &'a syn::Type,
147 bare_arg_type: Box<syn::Type>,
148 arg_ident: Ident,
149 comma: Option<Token![,]>,
150 colon: Option<Token![:]>,
151 receiver: Option<ReceiverArg<'a>>,
152}
153
154impl<'a> MethodArgInfo<'a> {
155 fn new(
156 arg_pair: syn::punctuated::Pair<&'a syn::FnArg, &'a Token![,]>,
157 names: &'a CommonNames,
158 arg_num: usize,
159 ) -> Result<Self> {
160 let arg = *arg_pair.value();
161 let comma = arg_pair.punct().map(|&&x| x);
162 let CommonNames {
163 private,
164 trait_object,
165 ..
166 } = names;
167 Ok(match arg {
168 syn::FnArg::Receiver(self_arg) => {
169 let receiver_arg = ReceiverArg::new(self_arg, names)?;
170 let pointer_to = match receiver_arg.type_ {
171 ReceiverType::SharedRef => quote!(*const),
172 ReceiverType::MutableRef => quote!(*mut),
173 };
174 MethodArgInfo {
175 arg_ident: receiver_arg.ident.clone(),
176 receiver: Some(receiver_arg),
177 colon: self_arg.colon_token,
178 needs_bare_transmute: BareConversionNeeded(false),
179 orig_arg_type: &*self_arg.ty,
180 bare_arg_type: Box::new(
181 syn::parse2(quote!(#private ::SelfPtr<#pointer_to #trait_object>)).unwrap(),
182 ),
183 comma,
184 }
185 }
186 syn::FnArg::Typed(pat_type) => {
187 let orig_arg_type = &pat_type.ty;
188 let (bare_arg_type, needs_bare_transmute) = to_bare_arg_type(&orig_arg_type)?;
189 MethodArgInfo {
190 arg_ident: Ident::new(&format!("arg{arg_num}"), Span::mixed_site()),
191 receiver: None,
192 colon: Some(pat_type.colon_token),
193 needs_bare_transmute,
194 orig_arg_type,
195 bare_arg_type,
196 comma,
197 }
198 }
199 })
200 }
201
202 fn into_bare_input_pair(self) -> syn::punctuated::Pair<syn::BareFnArg, Token![,]> {
203 let bare_arg = syn::BareFnArg {
204 attrs: Vec::new(),
205 name: Some((
206 self.arg_ident.clone(),
207 self.colon.unwrap_or_else(|| Token)),
208 )),
209 ty: *self.bare_arg_type,
210 };
211 syn::punctuated::Pair::new(bare_arg, self.comma)
212 }
213}
214
/// Records whether a type's vtable ("bare") form differs from the form written in the trait.
/// When it does, generated code converts between the two with
/// `runtime_layout_verified_transmute`.
struct BareConversionNeeded(
    /// `true` when a conversion is required.
    pub bool,
);
216
217struct TraitMethod<'a> {
218 sig: &'a syn::Signature,
219 args: Vec<MethodArgInfo<'a>>,
220 bare_output: syn::ReturnType,
221 output_needs_transmute: BareConversionNeeded,
222 receiver: ReceiverArg<'a>,
223}
224
225impl<'a> TraitMethod<'a> {
226 fn new(sig: &'a syn::Signature, names: &'a CommonNames) -> Result<Self> {
227 let generics = &sig.generics;
228 for generic_param in &generics.params {
229 if !matches!(generic_param, syn::GenericParam::Lifetime(_)) {
230 return Err(unimplemented(
231 &generics.params,
232 "non-lifetime method generic parameter",
233 ));
234 }
235 }
236
237 if let Some(where_clause) = &generics.where_clause {
238 for predicate in &where_clause.predicates {
239 if !matches!(predicate, syn::WherePredicate::Lifetime(_)) {
240 return Err(unimplemented(
241 where_clause,
242 "non-lifetime method where clause",
243 ));
244 }
245 }
246 }
247 let mut method_receiver = None;
248 let args = sig
249 .inputs
250 .pairs()
251 .enumerate()
252 .map(|(arg_num, arg_pair)| {
253 let arg_info = MethodArgInfo::new(arg_pair, names, arg_num)?;
254 if let Some(arg_receiver) = &arg_info.receiver {
255 assert!(method_receiver.is_none(), "more than one receiver");
256 method_receiver = Some(arg_receiver.clone());
257 }
258 Ok(arg_info)
259 })
260 .collect::<Result<_>>()?;
261 let Some(receiver) = method_receiver else {
262 return Err(unimplemented(sig, "non-reference methods"));
263 };
264 let (bare_output, output_needs_transmute) = match &sig.output {
265 syn::ReturnType::Default => (syn::ReturnType::Default, BareConversionNeeded(false)),
266 syn::ReturnType::Type(arrow, ty) => {
267 let (bare_arg_type, need_convert) = to_bare_arg_type(&*ty)?;
268 (
269 syn::ReturnType::Type(arrow.clone(), bare_arg_type),
270 need_convert,
271 )
272 }
273 };
274 Ok(Self {
275 receiver,
276 sig,
277 args,
278 bare_output,
279 output_needs_transmute,
280 })
281 }
282
283 fn drain_bare_inputs(&mut self) -> syn::punctuated::Punctuated<syn::BareFnArg, Token![,]> {
284 self.args
285 .drain(..)
286 .map(|method_info| method_info.into_bare_input_pair())
287 .collect()
288 }
289}
290
291struct TinydynImplModule {
293 names: CommonNames,
294 vtable_entries: Vec<TokenStream>,
300 vtable_callers: Vec<TokenStream>,
301 static_vtable_type: TokenStream,
303 static_vtable_expr: TokenStream,
305 metadata_type: TokenStream,
307 metadata_getter: TokenStream,
310}
311
312impl ToTokens for TinydynImplModule {
313 fn to_tokens(&self, tokens: &mut TokenStream) {
314 tokens.extend([self.to_token_stream()])
315 }
316
317 fn into_token_stream(self) -> TokenStream
318 where
319 Self: Sized,
320 {
321 let Self {
322 static_vtable_type,
323 static_vtable_expr,
324 metadata_type,
325 metadata_getter,
326 vtable_callers,
327 vtable_entries,
328 names:
329 CommonNames {
330 vtable_ident,
331 trait_ident,
332 trait_object,
333 tinydyn,
334 private,
335 concrete,
336 ..
337 },
338 ..
339 } = self;
340
341 let mod_ident = format_ident!("__tinydyn_impl_{trait_ident}");
342 let newtype_ident = format_ident!("{trait_ident}Newtype");
343
344 quote!(mod #mod_ident {
345 use super::*;
346
347 #[derive(Copy, Clone)]
348 pub struct #vtable_ident {
349 #(#vtable_entries,)*
350 }
351
352 #[repr(transparent)]
353 pub struct #newtype_ident <T>(T);
354
355 unsafe impl #tinydyn ::PlainDyn for #trait_object {
356 type Metadata = #metadata_type;
357 type StaticVTable = #static_vtable_type;
358 type LocalNewtype<T> = #newtype_ident <T>;
359 }
360
361 unsafe impl #tinydyn ::DynTrait for #trait_object {
362 type Plain = #trait_object;
363 type RemoveSend = #trait_object;
364 type RemoveSync = #trait_object;
365 }
366
367 unsafe impl #tinydyn ::DynTrait for #trait_object + Send {
368 type Plain = #trait_object;
369 type RemoveSend = #trait_object;
370 type RemoveSync = #trait_object + Send;
371 }
372
373 unsafe impl #tinydyn ::DynTrait for #trait_object + Sync {
374 type Plain = #trait_object;
375 type RemoveSend = #trait_object + Sync;
376 type RemoveSync = #trait_object;
377 }
378
379 unsafe impl #tinydyn ::DynTrait for #trait_object + Send + Sync {
380 type Plain = #trait_object;
381 type RemoveSend = #trait_object + Sync;
382 type RemoveSync = #trait_object + Send;
383 }
384
385 unsafe impl<#concrete> #tinydyn ::BuildDynMeta<#trait_object> for #newtype_ident <#concrete>
386 where
387 #concrete: #trait_ident,
388 {
389 const STATIC_VTABLE: #static_vtable_type = #static_vtable_expr;
390
391 fn metadata() -> #metadata_type {
392 #metadata_getter
393 }
394 }
395 unsafe impl<T> #tinydyn ::Implements<#trait_object> for #newtype_ident <T> where T: #trait_ident {}
396 unsafe impl<T> #tinydyn ::Implements<#trait_object + Send> for #newtype_ident <T> where T: #trait_ident + Send {}
397 unsafe impl<T> #tinydyn ::Implements<#trait_object + Sync> for #newtype_ident <T> where T: #trait_ident + Sync {}
398 unsafe impl<T> #tinydyn ::Implements<#trait_object + Send + Sync> for #newtype_ident <T> where T: #trait_ident + Send + Sync {}
399
400 impl<Trait> #trait_ident for #private ::DynTarget<Trait>
401 where
402 Trait: ?Sized + #tinydyn ::DynTrait<Plain = #trait_object>,
403 {
404 #(#vtable_callers)*
405 }
406 })
407 }
408
409 fn to_token_stream(&self) -> TokenStream {
410 self.clone().into_token_stream()
411 }
412}
413
414impl TinydynImplModule {
415 fn new(trait_item: ItemTrait) -> Result<Self> {
416 let ItemTrait {
417 generics,
418 ident: trait_ident,
419 supertraits,
420 items,
421 unsafety,
422 ..
423 } = trait_item;
424 generics_unimplemented(&generics)?;
425 supertraits_unimplemented(&supertraits)?;
426 unsafe_trait_unsupported(&unsafety)?;
427
428 let names = CommonNames::new(trait_ident);
429 let CommonNames {
430 self_local,
431 private,
432 trait_ident,
433 vtable_ident,
434 concrete,
435 meta_local,
436 ..
437 } = &names;
438
439 let fn_items: Vec<TraitItemFn> = items
440 .into_iter()
441 .map(|item| match item {
442 TraitItem::Fn(fn_item) => Ok(fn_item),
443 _ => Err(unimplemented(&item, "non-function items")),
444 })
445 .collect::<Result<_>>()?;
446
447 let mut vtable_entries: Vec<TokenStream> = Vec::new();
452 let mut vtable_builders: Vec<TokenStream> = Vec::new();
453 let mut vtable_callers: Vec<TokenStream> = Vec::new();
454 let methods: Vec<TraitMethod> = fn_items
455 .iter()
456 .map(|fn_item| TraitMethod::new(&fn_item.sig, &names))
457 .collect::<Result<_>>()?;
458 for mut method in methods {
459 let sig = method.sig;
460 let entry_ident = sig.ident.clone();
461 vtable_builders.push(quote!(
462 #entry_ident: core::mem::transmute(
463 <#concrete as #trait_ident>:: #entry_ident as *const ())
464 ));
465 let erased_cons = match method.receiver.type_ {
466 ReceiverType::SharedRef => quote!(self_ref),
467 ReceiverType::MutableRef => quote!(self_mut),
468 };
469 let mut impl_sig = sig.clone();
470 let mut call_args = Vec::new();
471 let mut args_to_bare = Vec::new();
472 for (mut pair, arg) in impl_sig.inputs.pairs_mut().zip(&method.args) {
473 let &MethodArgInfo {
475 orig_arg_type,
476 ref bare_arg_type,
477 ref arg_ident,
478 ..
479 } = arg;
480 if let syn::FnArg::Typed(pat_type) = pair.value_mut() {
481 pat_type.pat = Box::new(syn::Pat::Ident(syn::PatIdent {
482 attrs: Vec::new(),
483 by_ref: None,
484 mutability: None,
485 ident: arg_ident.clone(),
486 subpat: None,
487 }));
488 }
489
490 if arg.needs_bare_transmute.0 {
493 args_to_bare.push(quote!(
494 let #arg_ident = #private
495 ::runtime_layout_verified_transmute::<#orig_arg_type, #bare_arg_type>
496 (#arg_ident);
497 ));
498 }
499 call_args.push(arg_ident.to_token_stream());
501 }
502
503 let bare_inputs: Punctuated<syn::BareFnArg, Token![,]> = method.drain_bare_inputs();
504
505 let mut vtable_call = quote!((#meta_local . #entry_ident)(#(#call_args,)*));
506 if let (syn::ReturnType::Type(_, out_ty), syn::ReturnType::Type(_, bare_ty)) =
508 (&sig.output, &method.bare_output)
509 {
510 if method.output_needs_transmute.0 {
511 let out_ty = &*out_ty;
512 vtable_call = quote!(#private ::runtime_layout_verified_transmute::<#bare_ty, #out_ty>(
513 #vtable_call));
514 }
515 }
516
517 let fn_pointer = syn::TypeBareFn {
518 lifetimes: None,
519 unsafety: sig.unsafety.clone(),
520 abi: sig.abi.clone(),
521 fn_token: sig.fn_token.clone(),
522 paren_token: sig.paren_token.clone(),
523 inputs: bare_inputs,
524 variadic: None,
525 output: method.bare_output,
526 };
527 vtable_entries.push(quote!(#entry_ident: #fn_pointer));
528 vtable_callers.push(quote!(
529 #[inline(always)]
530 #impl_sig {
531 let #meta_local = #private ::DynTarget::meta(self);
532 let #self_local = #private ::DynTarget:: #erased_cons (self);
533 unsafe {
534 #(#args_to_bare)*
535 #vtable_call
536 }
537 }
538 ));
539 }
540
541 let vtable_build_expr = quote!(
542 unsafe {
543 #vtable_ident {
544 #(#vtable_builders,)*
545 }
546 }
547 );
548 let static_vtable_type; let static_vtable_expr; let metadata_type; let metadata_getter; if fn_items.len() <= 1 {
554 static_vtable_type = quote!(#private ::InlineVTable);
555 static_vtable_expr = static_vtable_type.clone();
556 metadata_type = vtable_ident.to_token_stream();
557 metadata_getter = vtable_build_expr;
558 } else {
559 static_vtable_type = vtable_ident.to_token_stream();
560 static_vtable_expr = vtable_build_expr;
561 metadata_type = quote!(&'static #vtable_ident);
562 metadata_getter = quote!(&Self::STATIC_VTABLE);
563 }
564
565 Ok(Self {
566 vtable_entries,
567 vtable_callers,
568 static_vtable_type,
569 static_vtable_expr,
570 metadata_type,
571 metadata_getter,
572 names,
573 })
574
575 }
577}
578
579fn to_bare_arg_type(arg_type: &syn::Type) -> Result<(Box<syn::Type>, BareConversionNeeded)> {
581 use syn::fold::Fold;
582 struct ReplaceLifetimesWith<'a> {
583 replace_with: syn::Lifetime,
584 needed_replace: &'a mut bool,
585 }
586 impl Fold for ReplaceLifetimesWith<'_> {
587 fn fold_lifetime(&mut self, lt: syn::Lifetime) -> syn::Lifetime {
588 if lt == self.replace_with {
589 lt
590 } else {
591 *self.needed_replace = true;
592 self.replace_with.clone()
593 }
594 }
595 fn fold_type_reference(&mut self, mut i: syn::TypeReference) -> syn::TypeReference {
596 if !matches!(&i.lifetime, Some(lt) if *lt == self.replace_with) {
597 *self.needed_replace = true;
598 i.lifetime = Some(self.replace_with.clone());
599 }
600 i
601 }
602 }
603 let mut needed_replace = false;
604 let bare_type = Box::new(
605 ReplaceLifetimesWith {
606 replace_with: syn::parse_str("'static").unwrap(),
607 needed_replace: &mut needed_replace,
608 }
609 .fold_type(arg_type.clone()),
610 );
611 Ok((bare_type, BareConversionNeeded(needed_replace)))
612}
613
614fn tinydyn_mod_impl(trait_item: ItemTrait) -> Result<TokenStream> {
615 TinydynImplModule::new(trait_item).map(ToTokens::into_token_stream)
616}
617
618#[proc_macro_attribute]
620pub fn tinydyn(
621 params: proc_macro::TokenStream,
622 item: proc_macro::TokenStream,
623) -> proc_macro::TokenStream {
624 if let Some(first_tt) = params.into_iter().next() {
625 return quote_spanned!(
626 first_tt.span().into()=>
627 compile_error!("params must be empty");
628 )
629 .into();
630 }
631 let original_tokens = item.clone();
632 let input = parse_macro_input!(item as ItemTrait);
633 tinydyn_mod_impl(input)
634 .map(move |mod_impl| {
635 let mut mod_impl = proc_macro::TokenStream::from(mod_impl);
636 mod_impl.extend([
637 "#[deny(elided_lifetimes_in_paths)]"
638 .parse::<proc_macro::TokenStream>()
639 .unwrap()
640 .into(),
641 original_tokens,
642 ]);
643 mod_impl
644 })
645 .unwrap_or_else(|e| e.into_compile_error().into())
646}