tinydyn_derive/
lib.rs

1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15extern crate proc_macro;
16use proc_macro2::{Ident, Span, TokenStream};
17
18use quote::{format_ident, quote, quote_spanned, ToTokens};
19use syn::{
20    parse_macro_input, punctuated::Punctuated, spanned::Spanned, Error, Generics, ItemTrait,
21    Result, Token, TraitItem, TraitItemFn, TypeParamBound,
22};
23
24fn unimplemented(x: &impl Spanned, things: &str) -> Error {
25    Error::new(
26        x.span(),
27        format!("{things} are not implemented for tinydyn"),
28    )
29}
30
31fn generics_unimplemented(generics: &Generics) -> Result<()> {
32    if let Some(where_clause) = &generics.where_clause {
33        return Err(unimplemented(where_clause, "where clauses"));
34    }
35    if !generics.params.is_empty() {
36        return Err(unimplemented(&generics.params, "generics"));
37    }
38    Ok(())
39}
40
41fn supertraits_unimplemented(supertraits: &Punctuated<TypeParamBound, Token![+]>) -> Result<()> {
42    if !supertraits.is_empty() {
43        return Err(unimplemented(&supertraits, "supertraits"));
44    }
45    Ok(())
46}
47
48fn unsafe_trait_unsupported(unsafety: &Option<Token![unsafe]>) -> Result<()> {
49    if let Some(unsafety) = unsafety {
50        return Err(unimplemented(unsafety, "unsafe traits"));
51    }
52    Ok(())
53}
54
55// TODO: refactor to properly separate out parsing logic and token generation logic.
56struct CommonNames {
57    tinydyn: Ident,
58    trait_ident: Ident,
59    trait_object: TokenStream,
60    private: TokenStream,
61    self_local: Ident,
62    meta_local: Ident,
63    vtable_ident: Ident,
64    concrete: TokenStream,
65}
66
67impl CommonNames {
68    fn new(trait_ident: Ident) -> Self {
69        let tinydyn = format_ident!("tinydyn");
70        let private = quote!(#tinydyn ::__private);
71        let self_local = Ident::new("self_", Span::mixed_site());
72        let meta_local = Ident::new("meta", Span::mixed_site());
73        let trait_object = quote!(dyn #trait_ident);
74        let vtable_ident = format_ident!("{trait_ident}Vtable");
75        let concrete = "Concrete".parse().unwrap();
76        Self {
77            tinydyn,
78            private,
79            self_local,
80            meta_local,
81            trait_ident,
82            trait_object,
83            vtable_ident,
84            concrete,
85        }
86    }
87}
88
89#[derive(Clone)]
90struct ReceiverArg<'a> {
91    type_: ReceiverType,
92    ident: &'a Ident,
93    elem: &'a syn::TypeReference,
94}
95
96impl<'a> ReceiverArg<'a> {
97    fn new(receiver: &'a syn::Receiver, names: &'a CommonNames) -> Result<Self> {
98        let syn::Type::Reference(elem) = &*receiver.ty else {
99            return Err(unimplemented(receiver, "non-reference methods"));
100        };
101        let type_;
102        let ident;
103        match &*elem.elem {
104            syn::Type::Path(path) if path.path.is_ident("Self") => {
105                ident = &names.self_local;
106                type_ = if elem.mutability.is_some() {
107                    ReceiverType::MutableRef
108                } else {
109                    ReceiverType::SharedRef
110                };
111            }
112            _ => return Err(unimplemented(receiver, "non-reference methods")),
113        };
114        Ok(Self { type_, elem, ident })
115    }
116}
117
/// The kind of reference a method takes `self` by.
#[derive(Copy, Clone)]
enum ReceiverType {
    /// Shared borrow: `&self`.
    SharedRef,

    /// Exclusive borrow: `&mut self`.
    MutableRef,
}
126
127impl ToTokens for ReceiverArg<'_> {
128    fn to_tokens(&self, tokens: &mut TokenStream) {
129        self.elem.to_tokens(tokens)
130    }
131}
132
133impl From<&Option<Token![mut]>> for ReceiverType {
134    fn from(mutability: &Option<Token![mut]>) -> Self {
135        use ReceiverType::*;
136        if mutability.is_some() {
137            MutableRef
138        } else {
139            SharedRef
140        }
141    }
142}
143
144struct MethodArgInfo<'a> {
145    needs_bare_transmute: BareConversionNeeded,
146    orig_arg_type: &'a syn::Type,
147    bare_arg_type: Box<syn::Type>,
148    arg_ident: Ident,
149    comma: Option<Token![,]>,
150    colon: Option<Token![:]>,
151    receiver: Option<ReceiverArg<'a>>,
152}
153
154impl<'a> MethodArgInfo<'a> {
155    fn new(
156        arg_pair: syn::punctuated::Pair<&'a syn::FnArg, &'a Token![,]>,
157        names: &'a CommonNames,
158        arg_num: usize,
159    ) -> Result<Self> {
160        let arg = *arg_pair.value();
161        let comma = arg_pair.punct().map(|&&x| x);
162        let CommonNames {
163            private,
164            trait_object,
165            ..
166        } = names;
167        Ok(match arg {
168            syn::FnArg::Receiver(self_arg) => {
169                let receiver_arg = ReceiverArg::new(self_arg, names)?;
170                let pointer_to = match receiver_arg.type_ {
171                    ReceiverType::SharedRef => quote!(*const),
172                    ReceiverType::MutableRef => quote!(*mut),
173                };
174                MethodArgInfo {
175                    arg_ident: receiver_arg.ident.clone(),
176                    receiver: Some(receiver_arg),
177                    colon: self_arg.colon_token,
178                    needs_bare_transmute: BareConversionNeeded(false),
179                    orig_arg_type: &*self_arg.ty,
180                    bare_arg_type: Box::new(
181                        syn::parse2(quote!(#private ::SelfPtr<#pointer_to #trait_object>)).unwrap(),
182                    ),
183                    comma,
184                }
185            }
186            syn::FnArg::Typed(pat_type) => {
187                let orig_arg_type = &pat_type.ty;
188                let (bare_arg_type, needs_bare_transmute) = to_bare_arg_type(&orig_arg_type)?;
189                MethodArgInfo {
190                    arg_ident: Ident::new(&format!("arg{arg_num}"), Span::mixed_site()),
191                    receiver: None,
192                    colon: Some(pat_type.colon_token),
193                    needs_bare_transmute,
194                    orig_arg_type,
195                    bare_arg_type,
196                    comma,
197                }
198            }
199        })
200    }
201
202    fn into_bare_input_pair(self) -> syn::punctuated::Pair<syn::BareFnArg, Token![,]> {
203        let bare_arg = syn::BareFnArg {
204            attrs: Vec::new(),
205            name: Some((
206                self.arg_ident.clone(),
207                self.colon.unwrap_or_else(|| Token![:](Span::call_site())),
208            )),
209            ty: *self.bare_arg_type,
210        };
211        syn::punctuated::Pair::new(bare_arg, self.comma)
212    }
213}
214
/// Flag recording whether lowering a type to its bare-fn form rewrote any lifetime, in which
/// case values of that type are transmuted between the original and the bare type.
struct BareConversionNeeded(pub bool);
216
217struct TraitMethod<'a> {
218    sig: &'a syn::Signature,
219    args: Vec<MethodArgInfo<'a>>,
220    bare_output: syn::ReturnType,
221    output_needs_transmute: BareConversionNeeded,
222    receiver: ReceiverArg<'a>,
223}
224
225impl<'a> TraitMethod<'a> {
226    fn new(sig: &'a syn::Signature, names: &'a CommonNames) -> Result<Self> {
227        let generics = &sig.generics;
228        for generic_param in &generics.params {
229            if !matches!(generic_param, syn::GenericParam::Lifetime(_)) {
230                return Err(unimplemented(
231                    &generics.params,
232                    "non-lifetime method generic parameter",
233                ));
234            }
235        }
236
237        if let Some(where_clause) = &generics.where_clause {
238            for predicate in &where_clause.predicates {
239                if !matches!(predicate, syn::WherePredicate::Lifetime(_)) {
240                    return Err(unimplemented(
241                        where_clause,
242                        "non-lifetime method where clause",
243                    ));
244                }
245            }
246        }
247        let mut method_receiver = None;
248        let args = sig
249            .inputs
250            .pairs()
251            .enumerate()
252            .map(|(arg_num, arg_pair)| {
253                let arg_info = MethodArgInfo::new(arg_pair, names, arg_num)?;
254                if let Some(arg_receiver) = &arg_info.receiver {
255                    assert!(method_receiver.is_none(), "more than one receiver");
256                    method_receiver = Some(arg_receiver.clone());
257                }
258                Ok(arg_info)
259            })
260            .collect::<Result<_>>()?;
261        let Some(receiver) = method_receiver else {
262            return Err(unimplemented(sig, "non-reference methods"));
263        };
264        let (bare_output, output_needs_transmute) = match &sig.output {
265            syn::ReturnType::Default => (syn::ReturnType::Default, BareConversionNeeded(false)),
266            syn::ReturnType::Type(arrow, ty) => {
267                let (bare_arg_type, need_convert) = to_bare_arg_type(&*ty)?;
268                (
269                    syn::ReturnType::Type(arrow.clone(), bare_arg_type),
270                    need_convert,
271                )
272            }
273        };
274        Ok(Self {
275            receiver,
276            sig,
277            args,
278            bare_output,
279            output_needs_transmute,
280        })
281    }
282
283    fn drain_bare_inputs(&mut self) -> syn::punctuated::Punctuated<syn::BareFnArg, Token![,]> {
284        self.args
285            .drain(..)
286            .map(|method_info| method_info.into_bare_input_pair())
287            .collect()
288    }
289}
290
291/// All of the data necessary to build the module that impls for `tinydyn`.
292struct TinydynImplModule {
293    names: CommonNames,
294    // trait_ident: Ident,
295
296    // trait_object: TokenStream,
297    // private: TokenStream,
298    // vtable_build_expr: TokenStream,
299    vtable_entries: Vec<TokenStream>,
300    vtable_callers: Vec<TokenStream>,
301    /// This is statically alloc'd for every (trait, concrete).
302    static_vtable_type: TokenStream,
303    /// This builds the `static_vtable_type` for this (trait, concrete).
304    static_vtable_expr: TokenStream,
305    /// This extra data is carried along in DynPtr.
306    metadata_type: TokenStream,
307    /// When building a wide pointer, this gets the metadata.
308    /// This might build a vtable or get a static one.
309    metadata_getter: TokenStream,
310}
311
312impl ToTokens for TinydynImplModule {
313    fn to_tokens(&self, tokens: &mut TokenStream) {
314        tokens.extend([self.to_token_stream()])
315    }
316
317    fn into_token_stream(self) -> TokenStream
318    where
319        Self: Sized,
320    {
321        let Self {
322            static_vtable_type,
323            static_vtable_expr,
324            metadata_type,
325            metadata_getter,
326            vtable_callers,
327            vtable_entries,
328            names:
329                CommonNames {
330                    vtable_ident,
331                    trait_ident,
332                    trait_object,
333                    tinydyn,
334                    private,
335                    concrete,
336                    ..
337                },
338            ..
339        } = self;
340
341        let mod_ident = format_ident!("__tinydyn_impl_{trait_ident}");
342        let newtype_ident = format_ident!("{trait_ident}Newtype");
343
344        quote!(mod #mod_ident {
345            use super::*;
346
347            #[derive(Copy, Clone)]
348            pub struct #vtable_ident {
349                #(#vtable_entries,)*
350            }
351
352            #[repr(transparent)]
353            pub struct #newtype_ident <T>(T);
354
355            unsafe impl #tinydyn ::PlainDyn for #trait_object {
356                type Metadata = #metadata_type;
357                type StaticVTable = #static_vtable_type;
358                type LocalNewtype<T> = #newtype_ident <T>;
359            }
360
361            unsafe impl #tinydyn ::DynTrait for #trait_object {
362                type Plain = #trait_object;
363                type RemoveSend = #trait_object;
364                type RemoveSync = #trait_object;
365            }
366
367            unsafe impl #tinydyn ::DynTrait for #trait_object + Send {
368                type Plain = #trait_object;
369                type RemoveSend = #trait_object;
370                type RemoveSync = #trait_object + Send;
371            }
372
373            unsafe impl #tinydyn ::DynTrait for #trait_object + Sync {
374                type Plain = #trait_object;
375                type RemoveSend = #trait_object + Sync;
376                type RemoveSync = #trait_object;
377            }
378
379            unsafe impl #tinydyn ::DynTrait for #trait_object + Send + Sync {
380                type Plain = #trait_object;
381                type RemoveSend = #trait_object + Sync;
382                type RemoveSync = #trait_object + Send;
383            }
384
385            unsafe impl<#concrete> #tinydyn ::BuildDynMeta<#trait_object> for #newtype_ident <#concrete>
386            where
387                #concrete: #trait_ident,
388            {
389                const STATIC_VTABLE: #static_vtable_type = #static_vtable_expr;
390
391                fn metadata() -> #metadata_type {
392                    #metadata_getter
393                }
394            }
395            unsafe impl<T> #tinydyn ::Implements<#trait_object> for #newtype_ident <T> where T: #trait_ident {}
396            unsafe impl<T> #tinydyn ::Implements<#trait_object + Send> for #newtype_ident <T> where T: #trait_ident + Send {}
397            unsafe impl<T> #tinydyn ::Implements<#trait_object + Sync> for #newtype_ident <T> where T: #trait_ident + Sync {}
398            unsafe impl<T> #tinydyn ::Implements<#trait_object + Send + Sync> for #newtype_ident <T> where T: #trait_ident + Send + Sync {}
399
400            impl<Trait> #trait_ident for #private ::DynTarget<Trait>
401            where
402                Trait: ?Sized + #tinydyn ::DynTrait<Plain = #trait_object>,
403            {
404                #(#vtable_callers)*
405            }
406        })
407    }
408
409    fn to_token_stream(&self) -> TokenStream {
410        self.clone().into_token_stream()
411    }
412}
413
414impl TinydynImplModule {
415    fn new(trait_item: ItemTrait) -> Result<Self> {
416        let ItemTrait {
417            generics,
418            ident: trait_ident,
419            supertraits,
420            items,
421            unsafety,
422            ..
423        } = trait_item;
424        generics_unimplemented(&generics)?;
425        supertraits_unimplemented(&supertraits)?;
426        unsafe_trait_unsupported(&unsafety)?;
427
428        let names = CommonNames::new(trait_ident);
429        let CommonNames {
430            self_local,
431            private,
432            trait_ident,
433            vtable_ident,
434            concrete,
435            meta_local,
436            ..
437        } = &names;
438
439        let fn_items: Vec<TraitItemFn> = items
440            .into_iter()
441            .map(|item| match item {
442                TraitItem::Fn(fn_item) => Ok(fn_item),
443                _ => Err(unimplemented(&item, "non-function items")),
444            })
445            .collect::<Result<_>>()?;
446
447        // vtable:
448        // - entries: the function pointer fields in the vtable
449        // - builders: the field initializers for the concrete type's vtable
450        // - callers: the trait impl methods on DynTarget that call a trait method from vtable
451        let mut vtable_entries: Vec<TokenStream> = Vec::new();
452        let mut vtable_builders: Vec<TokenStream> = Vec::new();
453        let mut vtable_callers: Vec<TokenStream> = Vec::new();
454        let methods: Vec<TraitMethod> = fn_items
455            .iter()
456            .map(|fn_item| TraitMethod::new(&fn_item.sig, &names))
457            .collect::<Result<_>>()?;
458        for mut method in methods {
459            let sig = method.sig;
460            let entry_ident = sig.ident.clone();
461            vtable_builders.push(quote!(
462                #entry_ident: core::mem::transmute(
463                    <#concrete as #trait_ident>:: #entry_ident as *const ())
464            ));
465            let erased_cons = match method.receiver.type_ {
466                ReceiverType::SharedRef => quote!(self_ref),
467                ReceiverType::MutableRef => quote!(self_mut),
468            };
469            let mut impl_sig = sig.clone();
470            let mut call_args = Vec::new();
471            let mut args_to_bare = Vec::new();
472            for (mut pair, arg) in impl_sig.inputs.pairs_mut().zip(&method.args) {
473                // Replace with our custom argument name
474                let &MethodArgInfo {
475                    orig_arg_type,
476                    ref bare_arg_type,
477                    ref arg_ident,
478                    ..
479                } = arg;
480                if let syn::FnArg::Typed(pat_type) = pair.value_mut() {
481                    pat_type.pat = Box::new(syn::Pat::Ident(syn::PatIdent {
482                        attrs: Vec::new(),
483                        by_ref: None,
484                        mutability: None,
485                        ident: arg_ident.clone(),
486                        subpat: None,
487                    }));
488                }
489
490                // Erase lifetimes and prepare for the bare fn (pointer)
491                // `transmute` doesn't work with generic arguments, but `transmute_copy` does.
492                if arg.needs_bare_transmute.0 {
493                    args_to_bare.push(quote!(
494                        let #arg_ident = #private
495                            ::runtime_layout_verified_transmute::<#orig_arg_type, #bare_arg_type>
496                            (#arg_ident);
497                    ));
498                }
499                // The argument in the vtable method call
500                call_args.push(arg_ident.to_token_stream());
501            }
502
503            let bare_inputs: Punctuated<syn::BareFnArg, Token![,]> = method.drain_bare_inputs();
504
505            let mut vtable_call = quote!((#meta_local . #entry_ident)(#(#call_args,)*));
506            // don't forget to transmute the output type if it needs it
507            if let (syn::ReturnType::Type(_, out_ty), syn::ReturnType::Type(_, bare_ty)) =
508                (&sig.output, &method.bare_output)
509            {
510                if method.output_needs_transmute.0 {
511                    let out_ty = &*out_ty;
512                    vtable_call = quote!(#private ::runtime_layout_verified_transmute::<#bare_ty, #out_ty>(
513                            #vtable_call));
514                }
515            }
516
517            let fn_pointer = syn::TypeBareFn {
518                lifetimes: None,
519                unsafety: sig.unsafety.clone(),
520                abi: sig.abi.clone(),
521                fn_token: sig.fn_token.clone(),
522                paren_token: sig.paren_token.clone(),
523                inputs: bare_inputs,
524                variadic: None,
525                output: method.bare_output,
526            };
527            vtable_entries.push(quote!(#entry_ident: #fn_pointer));
528            vtable_callers.push(quote!(
529                #[inline(always)]
530                #impl_sig {
531                    let #meta_local = #private ::DynTarget::meta(self);
532                    let #self_local = #private ::DynTarget:: #erased_cons (self);
533                    unsafe {
534                        #(#args_to_bare)*
535                        #vtable_call
536                    }
537                }
538            ));
539        }
540
541        let vtable_build_expr = quote!(
542            unsafe {
543                #vtable_ident {
544                    #(#vtable_builders,)*
545                }
546            }
547        );
548        let static_vtable_type; // This is statically alloc'd for every (trait, concrete).
549        let static_vtable_expr; // This builds the above.
550        let metadata_type; // This extra data is carried along in DynPtr.
551        let metadata_getter; // When building a wide pointer, this gets the metadata.
552
553        if fn_items.len() <= 1 {
554            static_vtable_type = quote!(#private ::InlineVTable);
555            static_vtable_expr = static_vtable_type.clone();
556            metadata_type = vtable_ident.to_token_stream();
557            metadata_getter = vtable_build_expr;
558        } else {
559            static_vtable_type = vtable_ident.to_token_stream();
560            static_vtable_expr = vtable_build_expr;
561            metadata_type = quote!(&'static #vtable_ident);
562            metadata_getter = quote!(&Self::STATIC_VTABLE);
563        }
564
565        Ok(Self {
566            vtable_entries,
567            vtable_callers,
568            static_vtable_type,
569            static_vtable_expr,
570            metadata_type,
571            metadata_getter,
572            names,
573        })
574
575        // self.static_vtable_type.to_tokens()
576    }
577}
578
579/// Returns (bare fn type, whether it needed the conversion)
580fn to_bare_arg_type(arg_type: &syn::Type) -> Result<(Box<syn::Type>, BareConversionNeeded)> {
581    use syn::fold::Fold;
582    struct ReplaceLifetimesWith<'a> {
583        replace_with: syn::Lifetime,
584        needed_replace: &'a mut bool,
585    }
586    impl Fold for ReplaceLifetimesWith<'_> {
587        fn fold_lifetime(&mut self, lt: syn::Lifetime) -> syn::Lifetime {
588            if lt == self.replace_with {
589                lt
590            } else {
591                *self.needed_replace = true;
592                self.replace_with.clone()
593            }
594        }
595        fn fold_type_reference(&mut self, mut i: syn::TypeReference) -> syn::TypeReference {
596            if !matches!(&i.lifetime, Some(lt) if *lt == self.replace_with) {
597                *self.needed_replace = true;
598                i.lifetime = Some(self.replace_with.clone());
599            }
600            i
601        }
602    }
603    let mut needed_replace = false;
604    let bare_type = Box::new(
605        ReplaceLifetimesWith {
606            replace_with: syn::parse_str("'static").unwrap(),
607            needed_replace: &mut needed_replace,
608        }
609        .fold_type(arg_type.clone()),
610    );
611    Ok((bare_type, BareConversionNeeded(needed_replace)))
612}
613
614fn tinydyn_mod_impl(trait_item: ItemTrait) -> Result<TokenStream> {
615    TinydynImplModule::new(trait_item).map(ToTokens::into_token_stream)
616}
617
618/// This trait is `tinydyn`-compatible.
619#[proc_macro_attribute]
620pub fn tinydyn(
621    params: proc_macro::TokenStream,
622    item: proc_macro::TokenStream,
623) -> proc_macro::TokenStream {
624    if let Some(first_tt) = params.into_iter().next() {
625        return quote_spanned!(
626            first_tt.span().into()=>
627            compile_error!("params must be empty");
628        )
629        .into();
630    }
631    let original_tokens = item.clone();
632    let input = parse_macro_input!(item as ItemTrait);
633    tinydyn_mod_impl(input)
634        .map(move |mod_impl| {
635            let mut mod_impl = proc_macro::TokenStream::from(mod_impl);
636            mod_impl.extend([
637                "#[deny(elided_lifetimes_in_paths)]"
638                    .parse::<proc_macro::TokenStream>()
639                    .unwrap()
640                    .into(),
641                original_tokens,
642            ]);
643            mod_impl
644        })
645        .unwrap_or_else(|e| e.into_compile_error().into())
646}