static_dispatch_macros/
lib.rs1use std::ops::Not;
2
3use proc_macro::TokenStream;
4use quote::{ToTokens, format_ident, quote};
5use syn::{
6 Error, FnArg, GenericParam, Generics, Ident, ItemEnum, ItemTrait, Path, PathArguments,
7 ReturnType, TraitItem, TraitItemFn, Type, TypeGenerics, TypeReference,
8};
9
10#[proc_macro_attribute]
11pub fn dispatch(attr: TokenStream, item: TokenStream) -> TokenStream {
13 let item = proc_macro2::TokenStream::from(item);
14
15 let output = if let Ok(input_trait) = syn::parse2(item.clone()) {
16 dispatch_trait(attr, input_trait)
17 } else if let Ok(input_trait) = syn::parse2(item.clone()) {
18 dispatch_enum(attr, input_trait)
19 } else {
20 Error::new_spanned(&item, "Could not parse as trait or enum").to_compile_error()
21 };
22
23 quote! {
24 #item
25 #output
26 }
27 .into()
28}
29
30fn is_self_type(ty: &Type) -> bool {
31 match ty {
32 Type::Path(type_path) => {
34 type_path.qself.is_none()
35 && type_path.path.segments.len() == 1
36 && type_path.path.segments[0].ident == "Self"
37 && matches!(type_path.path.segments[0].arguments, PathArguments::None)
38 }
39 Type::Reference(TypeReference { elem, .. }) => is_self_type(elem),
41 _ => false,
42 }
43}
44
45fn is_valid_self(arg: Option<&FnArg>) -> bool {
46 let Some(FnArg::Receiver(receiver)) = arg else {
47 return false;
48 };
49 receiver.colon_token.is_none() || is_self_type(&receiver.ty)
50}
51
52fn generics_for_method(generics: &Generics) -> proc_macro2::TokenStream {
53 let mut generics = generics.params.iter().filter_map(|generic| match generic {
54 GenericParam::Lifetime(_) => None,
55 GenericParam::Const(const_generic) => Some(&const_generic.ident),
56 GenericParam::Type(type_generic) => Some(&type_generic.ident),
57 });
58 let Some(first) = generics.next() else {
59 return proc_macro2::TokenStream::new();
60 };
61 let mut res = quote! {::<#first};
62 for generic in generics {
63 quote! {, #generic}.to_tokens(&mut res);
64 }
65 quote! {>}.to_tokens(&mut res);
66 res
67}
68
69fn create_trait_item_macro(
70 trait_name: &Ident,
71 trait_generic: &TypeGenerics,
72 method: &TraitItemFn,
73) -> proc_macro2::TokenStream {
74 let TraitItemFn {
75 attrs,
76 sig,
77 default: _,
78 semi_token: _,
79 } = method;
80
81 let name = &sig.ident;
82
83 if is_valid_self(sig.inputs.first()).not() {
84 return Error::new_spanned(
85 method,
86 "Only methods with `self`, `&self` or `&mut self` are supported",
87 )
88 .to_compile_error();
89 }
90
91 let suffix = match sig.asyncness.is_some() {
92 false => quote! {},
93 true => quote! { .await },
94 };
95
96 if let ReturnType::Type(_, ty) = &sig.output
97 && let Type::ImplTrait(impl_trait) = ty.as_ref()
98 {
99 return Error::new_spanned(impl_trait, "Return impl is not supported").to_compile_error();
100 }
101
102 let remaining_inputs = sig.inputs.iter().skip(1).map(|arg| match arg {
103 FnArg::Receiver(rec) => {
104 Error::new_spanned(rec, "Self only as first argument please").to_compile_error()
105 }
106 FnArg::Typed(typed) => {
107 let name = typed.pat.as_ref();
108 quote! { , #name }
109 }
110 });
111
112 let generics = generics_for_method(&sig.generics);
113
114 quote! {
115 #(#attrs)* #sig {
116 match self {
117 $(
118 Self::$variant_name(__static_dispatch_value) => <$variant_type as #trait_name #trait_generic>::#name #generics(
119 __static_dispatch_value
120 #(#remaining_inputs)*
121 )#suffix,
122 )*
123 }
124 }
125 }
126}
127
128fn macro_name(ident: &Ident) -> Ident {
129 format_ident!("{}_static_dispatch_macro", ident)
130}
131
132fn dispatch_trait(attr: TokenStream, input: ItemTrait) -> proc_macro2::TokenStream {
133 let export = if attr.is_empty() {
134 false
135 } else {
136 let ident = match syn::parse::<Ident>(attr) {
137 Ok(ident) => ident,
138 Err(err) => return err.to_compile_error(),
139 };
140 if ident != "macro_export" {
141 return Error::new_spanned(&ident, "Only \"macro_export\" is allowed as attribute.")
142 .to_compile_error();
143 }
144 true
145 };
146
147 let trait_name = &input.ident;
148 let macro_name = macro_name(trait_name);
149 let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl();
150
151 let items = input.items.iter().map(|item| match item {
152 TraitItem::Fn(method) => create_trait_item_macro(trait_name, ty_generics, method),
153 item => Error::new_spanned(item, "Only methods are supported").to_compile_error(),
154 });
155
156 let export_prefix = match export {
157 false => quote! {},
158 true => quote! { #[macro_export] },
159 };
160
161 let visibility = &input.vis;
162 let use_statement = match export {
163 false => quote! { #visibility use #macro_name; },
164 true => quote! {},
165 };
166
167 quote! {
168 #export_prefix
170 macro_rules! #macro_name {
171 (
172 $vis:vis enum $name:ident {
173 $($variant_name:ident($variant_type:ty),)*
174 }
175 ) => {
176 impl #impl_generics #trait_name #ty_generics for $name #where_clause {
177 #(#items)*
178 }
179 };
180 }
181 #use_statement
182 }
183}
184
185fn edit_trait_path(trait_path: &mut Path) -> Result<(), proc_macro2::TokenStream> {
186 match trait_path.segments.last_mut() {
187 Some(segment) => {
188 segment.ident = macro_name(&segment.ident);
189 segment.arguments = PathArguments::None;
190 Ok(())
191 }
192 None => Err(
193 Error::new_spanned(trait_path, "Name or Path of the trait required").to_compile_error(),
194 ),
195 }
196}
197
198fn dispatch_enum(attr: TokenStream, input: ItemEnum) -> proc_macro2::TokenStream {
199 let enum_name = &input.ident;
200 let vis = &input.vis;
201 let variants = input.variants.iter();
202
203 let attr = proc_macro2::TokenStream::from(attr);
204
205 let Ok(mut trait_path) = syn::parse2::<Path>(attr.clone()) else {
206 return Error::new_spanned(attr, "Path or impl trait for type signature expected")
207 .to_compile_error();
208 };
209 if let Err(err) = edit_trait_path(&mut trait_path) {
210 return err;
211 }
212 quote! {
213 #trait_path! {
214 #vis enum #enum_name {
215 #(#variants,)*
216 }
217 }
218 }
219}