static_dispatch_macros/
lib.rs

1use std::ops::Not;
2
3use proc_macro::TokenStream;
4use quote::{ToTokens, format_ident, quote};
5use syn::{
6    Error, FnArg, GenericParam, Generics, Ident, ItemEnum, ItemTrait, Path, PathArguments,
7    ReturnType, Token, TraitItem, TraitItemFn, Type, TypeGenerics, TypeReference, parse::Parse,
8};
9
10#[proc_macro_attribute]
11/// See the module for documentation.
12pub fn dispatch(attr: TokenStream, item: TokenStream) -> TokenStream {
13    let item = proc_macro2::TokenStream::from(item);
14
15    let output = if let Ok(input_trait) = syn::parse2(item.clone()) {
16        dispatch_trait(attr, input_trait)
17    } else if let Ok(input_trait) = syn::parse2(item.clone()) {
18        dispatch_enum(attr, input_trait)
19    } else {
20        Error::new_spanned(&item, "Could not parse as trait or enum").to_compile_error()
21    };
22
23    quote! {
24        #item
25        #output
26    }
27    .into()
28}
29
30fn is_self_type(ty: &Type) -> bool {
31    match ty {
32        // Check for plain `Self`
33        Type::Path(type_path) => {
34            type_path.qself.is_none()
35                && type_path.path.segments.len() == 1
36                && type_path.path.segments[0].ident == "Self"
37                && matches!(type_path.path.segments[0].arguments, PathArguments::None)
38        }
39        // Check for `&Self` or `&mut Self`
40        Type::Reference(TypeReference { elem, .. }) => is_self_type(elem),
41        _ => false,
42    }
43}
44
45fn is_valid_self(arg: Option<&FnArg>) -> bool {
46    let Some(FnArg::Receiver(receiver)) = arg else {
47        return false;
48    };
49    receiver.colon_token.is_none() || is_self_type(&receiver.ty)
50}
51
52fn generics_for_method(generics: &Generics) -> proc_macro2::TokenStream {
53    let mut generics = generics.params.iter().filter_map(|generic| match generic {
54        GenericParam::Lifetime(_) => None,
55        GenericParam::Const(const_generic) => Some(&const_generic.ident),
56        GenericParam::Type(type_generic) => Some(&type_generic.ident),
57    });
58    let Some(first) = generics.next() else {
59        return proc_macro2::TokenStream::new();
60    };
61    let mut res = quote! {::<#first};
62    for generic in generics {
63        quote! {, #generic}.to_tokens(&mut res);
64    }
65    quote! {>}.to_tokens(&mut res);
66    res
67}
68
69fn create_trait_item_macro(
70    trait_name: &Ident,
71    trait_generic: &TypeGenerics,
72    method: &TraitItemFn,
73    long_form: bool,
74) -> proc_macro2::TokenStream {
75    let TraitItemFn {
76        attrs,
77        sig,
78        default: _,
79        semi_token: _,
80    } = method;
81
82    let name = &sig.ident;
83
84    if is_valid_self(sig.inputs.first()).not() {
85        return Error::new_spanned(
86            method,
87            "Only methods with `self`, `&self` or `&mut self` are supported",
88        )
89        .to_compile_error();
90    }
91
92    let suffix = match sig.asyncness.is_some() {
93        false => quote! {},
94        true => quote! { .await },
95    };
96
97    if let ReturnType::Type(_, ty) = &sig.output
98        && let Type::ImplTrait(impl_trait) = ty.as_ref()
99    {
100        return Error::new_spanned(impl_trait, "Return impl is not supported").to_compile_error();
101    }
102
103    let remaining_inputs = sig.inputs.iter().skip(1).map(|arg| match arg {
104        FnArg::Receiver(rec) => {
105            Error::new_spanned(rec, "Self only as first argument please").to_compile_error()
106        }
107        FnArg::Typed(typed) => {
108            let name = typed.pat.as_ref();
109            quote! { , #name }
110        }
111    });
112
113    let generics = generics_for_method(&sig.generics);
114
115    let trait_type = match long_form {
116        false => quote! { #trait_name #trait_generic },
117        true => quote! { $trait_type },
118    };
119
120    quote! {
121        #(#attrs)* #sig {
122            match self {
123                $(
124                    Self::$variant_name(__static_dispatch_value) => <$variant_type as #trait_type>::#name #generics(
125                        __static_dispatch_value
126                        #(#remaining_inputs)*
127                    )#suffix,
128                )*
129            }
130        }
131    }
132}
133
134fn macro_name(ident: &Ident) -> Ident {
135    format_ident!("{}_static_dispatch_macro", ident)
136}
137
138fn dispatch_trait(attr: TokenStream, input: ItemTrait) -> proc_macro2::TokenStream {
139    let export = if attr.is_empty() {
140        false
141    } else {
142        let ident = match syn::parse::<Ident>(attr) {
143            Ok(ident) => ident,
144            Err(err) => return err.to_compile_error(),
145        };
146        if ident != "macro_export" {
147            return Error::new_spanned(&ident, "Only \"macro_export\" is allowed as attribute.")
148                .to_compile_error();
149        }
150        true
151    };
152
153    let trait_name = &input.ident;
154    let macro_name = macro_name(trait_name);
155    let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl();
156
157    let short_items = input.items.iter().map(|item| match item {
158        TraitItem::Fn(method) => create_trait_item_macro(trait_name, ty_generics, method, false),
159        item => Error::new_spanned(item, "Only methods are supported").to_compile_error(),
160    });
161
162    let long_items = input.items.iter().map(|item| match item {
163        TraitItem::Fn(method) => create_trait_item_macro(trait_name, ty_generics, method, true),
164        item => Error::new_spanned(item, "Only methods are supported").to_compile_error(),
165    });
166
167    let export_prefix = match export {
168        false => quote! {},
169        true => quote! { #[macro_export] },
170    };
171
172    let visibility = &input.vis;
173    let use_statement = match export {
174        false => quote! { #visibility use #macro_name; },
175        true => quote! {},
176    };
177
178    quote! {
179        /// This is just the macro static dispatch uses to create the implementation for the enum.
180        #export_prefix
181        macro_rules! #macro_name {
182            (
183                short
184                $vis:vis enum $name:ident {
185                    $($variant_name:ident($variant_type:ty),)*
186                }
187            ) => {
188                impl #impl_generics #trait_name #ty_generics for $name #where_clause {
189                    #(#short_items)*
190                }
191            };
192            (
193                long
194                $trait_type:ty
195                {
196                    $($variant_name:ident($variant_type:ty),)*
197                }
198                $($rem:tt)*
199            ) => {
200                $($rem)* {
201                    #(#long_items)*
202                }
203            };
204        }
205        #use_statement
206    }
207}
208
209fn edit_trait_path(trait_path: &mut Path) -> Result<(), proc_macro2::TokenStream> {
210    match trait_path.segments.last_mut() {
211        Some(segment) => {
212            segment.ident = macro_name(&segment.ident);
213            segment.arguments = PathArguments::None;
214            Ok(())
215        }
216        None => Err(
217            Error::new_spanned(trait_path, "Name or Path of the trait required").to_compile_error(),
218        ),
219    }
220}
221
222struct LongImpl {
223    _impl: Token![impl],
224    generics: Generics,
225    trait_: Path,
226    _for: Token![for],
227    self_ty: Type,
228}
229
230impl Parse for LongImpl {
231    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
232        Ok(Self {
233            _impl: input.parse()?,
234            generics: input.parse()?,
235            trait_: input.parse()?,
236            _for: input.parse()?,
237            self_ty: input.parse()?,
238        })
239    }
240}
241
242fn dispatch_enum(attr: TokenStream, input: ItemEnum) -> proc_macro2::TokenStream {
243    let enum_name = &input.ident;
244    let vis = &input.vis;
245    let variants = input.variants.iter();
246
247    let attr = proc_macro2::TokenStream::from(attr);
248
249    if let Ok(mut trait_path) = syn::parse2::<Path>(attr.clone()) {
250        if let Err(err) = edit_trait_path(&mut trait_path) {
251            return err;
252        }
253        return quote! {
254            #trait_path! {
255                short
256                #vis enum #enum_name {
257                    #(#variants,)*
258                }
259            }
260        };
261    }
262
263    let item_impl = match syn::parse2::<LongImpl>(attr) {
264        Ok(item_impl) => item_impl,
265        Err(err) => return err.into_compile_error(),
266    };
267
268    let mut trait_path = item_impl.trait_.clone();
269    if let Err(err) = edit_trait_path(&mut trait_path) {
270        return err;
271    }
272
273    let (impl_generics, _ty_generics, where_clause) = item_impl.generics.split_for_impl();
274    let trait_name = &item_impl.trait_;
275    let name = item_impl.self_ty;
276
277    quote! {
278
279        #trait_path! {
280            long
281            #trait_name
282            {
283                #(#variants,)*
284            }
285            impl #impl_generics #trait_name for #name #where_clause
286        }
287    }
288}