static_dispatch_macros/
lib.rs

1use std::ops::Not;
2
3use proc_macro::TokenStream;
4use quote::{ToTokens, format_ident, quote};
5use syn::{
6    Error, FnArg, GenericParam, Generics, Ident, ItemEnum, ItemTrait, Path, PathArguments,
7    ReturnType, TraitItem, TraitItemFn, Type, TypeGenerics, TypeReference,
8};
9
10#[proc_macro_attribute]
11/// See the module for documentation.
12pub fn dispatch(attr: TokenStream, item: TokenStream) -> TokenStream {
13    let item = proc_macro2::TokenStream::from(item);
14
15    let output = if let Ok(input_trait) = syn::parse2(item.clone()) {
16        dispatch_trait(attr, input_trait)
17    } else if let Ok(input_trait) = syn::parse2(item.clone()) {
18        dispatch_enum(attr, input_trait)
19    } else {
20        Error::new_spanned(&item, "Could not parse as trait or enum").to_compile_error()
21    };
22
23    quote! {
24        #item
25        #output
26    }
27    .into()
28}
29
30fn is_self_type(ty: &Type) -> bool {
31    match ty {
32        // Check for plain `Self`
33        Type::Path(type_path) => {
34            type_path.qself.is_none()
35                && type_path.path.segments.len() == 1
36                && type_path.path.segments[0].ident == "Self"
37                && matches!(type_path.path.segments[0].arguments, PathArguments::None)
38        }
39        // Check for `&Self` or `&mut Self`
40        Type::Reference(TypeReference { elem, .. }) => is_self_type(elem),
41        _ => false,
42    }
43}
44
45fn is_valid_self(arg: Option<&FnArg>) -> bool {
46    let Some(FnArg::Receiver(receiver)) = arg else {
47        return false;
48    };
49    receiver.colon_token.is_none() || is_self_type(&receiver.ty)
50}
51
52fn generics_for_method(generics: &Generics) -> proc_macro2::TokenStream {
53    let mut generics = generics.params.iter().filter_map(|generic| match generic {
54        GenericParam::Lifetime(_) => None,
55        GenericParam::Const(const_generic) => Some(&const_generic.ident),
56        GenericParam::Type(type_generic) => Some(&type_generic.ident),
57    });
58    let Some(first) = generics.next() else {
59        return proc_macro2::TokenStream::new();
60    };
61    let mut res = quote! {::<#first};
62    for generic in generics {
63        quote! {, #generic}.to_tokens(&mut res);
64    }
65    quote! {>}.to_tokens(&mut res);
66    res
67}
68
69fn create_trait_item_macro(
70    trait_name: &Ident,
71    trait_generic: &TypeGenerics,
72    method: &TraitItemFn,
73) -> proc_macro2::TokenStream {
74    let TraitItemFn {
75        attrs,
76        sig,
77        default: _,
78        semi_token: _,
79    } = method;
80
81    let name = &sig.ident;
82
83    if is_valid_self(sig.inputs.first()).not() {
84        return Error::new_spanned(
85            method,
86            "Only methods with `self`, `&self` or `&mut self` are supported",
87        )
88        .to_compile_error();
89    }
90
91    let suffix = match sig.asyncness.is_some() {
92        false => quote! {},
93        true => quote! { .await },
94    };
95
96    if let ReturnType::Type(_, ty) = &sig.output
97        && let Type::ImplTrait(impl_trait) = ty.as_ref()
98    {
99        return Error::new_spanned(impl_trait, "Return impl is not supported").to_compile_error();
100    }
101
102    let remaining_inputs = sig.inputs.iter().skip(1).map(|arg| match arg {
103        FnArg::Receiver(rec) => {
104            Error::new_spanned(rec, "Self only as first argument please").to_compile_error()
105        }
106        FnArg::Typed(typed) => {
107            let name = typed.pat.as_ref();
108            quote! { , #name }
109        }
110    });
111
112    let generics = generics_for_method(&sig.generics);
113
114    quote! {
115        #(#attrs)* #sig {
116            match self {
117                $(
118                    Self::$variant_name(__static_dispatch_value) => <$variant_type as #trait_name #trait_generic>::#name #generics(
119                        __static_dispatch_value
120                        #(#remaining_inputs)*
121                    )#suffix,
122                )*
123            }
124        }
125    }
126}
127
128fn macro_name(ident: &Ident) -> Ident {
129    format_ident!("{}_static_dispatch_macro", ident)
130}
131
132fn dispatch_trait(attr: TokenStream, input: ItemTrait) -> proc_macro2::TokenStream {
133    let export = if attr.is_empty() {
134        false
135    } else {
136        let ident = match syn::parse::<Ident>(attr) {
137            Ok(ident) => ident,
138            Err(err) => return err.to_compile_error(),
139        };
140        if ident != "macro_export" {
141            return Error::new_spanned(&ident, "Only \"macro_export\" is allowed as attribute.")
142                .to_compile_error();
143        }
144        true
145    };
146
147    let trait_name = &input.ident;
148    let macro_name = macro_name(trait_name);
149    let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl();
150
151    let items = input.items.iter().map(|item| match item {
152        TraitItem::Fn(method) => create_trait_item_macro(trait_name, ty_generics, method),
153        item => Error::new_spanned(item, "Only methods are supported").to_compile_error(),
154    });
155
156    let export_prefix = match export {
157        false => quote! {},
158        true => quote! { #[macro_export] },
159    };
160
161    let visibility = &input.vis;
162    let use_statement = match export {
163        false => quote! { #visibility use #macro_name; },
164        true => quote! {},
165    };
166
167    quote! {
168        /// This is just the macro static dispatch uses to create the implementation for the enum.
169        #export_prefix
170        macro_rules! #macro_name {
171            (
172                $vis:vis enum $name:ident {
173                    $($variant_name:ident($variant_type:ty),)*
174                }
175            ) => {
176                impl #impl_generics #trait_name #ty_generics for $name #where_clause {
177                    #(#items)*
178                }
179            };
180        }
181        #use_statement
182    }
183}
184
185fn edit_trait_path(trait_path: &mut Path) -> Result<(), proc_macro2::TokenStream> {
186    match trait_path.segments.last_mut() {
187        Some(segment) => {
188            segment.ident = macro_name(&segment.ident);
189            segment.arguments = PathArguments::None;
190            Ok(())
191        }
192        None => Err(
193            Error::new_spanned(trait_path, "Name or Path of the trait required").to_compile_error(),
194        ),
195    }
196}
197
198fn dispatch_enum(attr: TokenStream, input: ItemEnum) -> proc_macro2::TokenStream {
199    let enum_name = &input.ident;
200    let vis = &input.vis;
201    let variants = input.variants.iter();
202
203    let attr = proc_macro2::TokenStream::from(attr);
204
205    let Ok(mut trait_path) = syn::parse2::<Path>(attr.clone()) else {
206        return Error::new_spanned(attr, "Path or impl trait for type signature expected")
207            .to_compile_error();
208    };
209    if let Err(err) = edit_trait_path(&mut trait_path) {
210        return err;
211    }
212    quote! {
213        #trait_path! {
214            #vis enum #enum_name {
215                #(#variants,)*
216            }
217        }
218    }
219}