1use std::ops::Not;
2
3use proc_macro::TokenStream;
4use quote::{ToTokens, format_ident, quote};
5use syn::{
6 Error, FnArg, GenericParam, Generics, Ident, ItemEnum, ItemTrait, Path, PathArguments,
7 ReturnType, Token, TraitItem, TraitItemFn, Type, TypeGenerics, TypeReference, parse::Parse,
8};
9
10#[proc_macro_attribute]
11pub fn dispatch(attr: TokenStream, item: TokenStream) -> TokenStream {
13 let item = proc_macro2::TokenStream::from(item);
14
15 let output = if let Ok(input_trait) = syn::parse2(item.clone()) {
16 dispatch_trait(attr, input_trait)
17 } else if let Ok(input_trait) = syn::parse2(item.clone()) {
18 dispatch_enum(attr, input_trait)
19 } else {
20 Error::new_spanned(&item, "Could not parse as trait or enum").to_compile_error()
21 };
22
23 quote! {
24 #item
25 #output
26 }
27 .into()
28}
29
30fn is_self_type(ty: &Type) -> bool {
31 match ty {
32 Type::Path(type_path) => {
34 type_path.qself.is_none()
35 && type_path.path.segments.len() == 1
36 && type_path.path.segments[0].ident == "Self"
37 && matches!(type_path.path.segments[0].arguments, PathArguments::None)
38 }
39 Type::Reference(TypeReference { elem, .. }) => is_self_type(elem),
41 _ => false,
42 }
43}
44
45fn is_valid_self(arg: Option<&FnArg>) -> bool {
46 let Some(FnArg::Receiver(receiver)) = arg else {
47 return false;
48 };
49 receiver.colon_token.is_none() || is_self_type(&receiver.ty)
50}
51
52fn generics_for_method(generics: &Generics) -> proc_macro2::TokenStream {
53 let mut generics = generics.params.iter().filter_map(|generic| match generic {
54 GenericParam::Lifetime(_) => None,
55 GenericParam::Const(const_generic) => Some(&const_generic.ident),
56 GenericParam::Type(type_generic) => Some(&type_generic.ident),
57 });
58 let Some(first) = generics.next() else {
59 return proc_macro2::TokenStream::new();
60 };
61 let mut res = quote! {::<#first};
62 for generic in generics {
63 quote! {, #generic}.to_tokens(&mut res);
64 }
65 quote! {>}.to_tokens(&mut res);
66 res
67}
68
69fn create_trait_item_macro(
70 trait_name: &Ident,
71 trait_generic: &TypeGenerics,
72 method: &TraitItemFn,
73 long_form: bool,
74) -> proc_macro2::TokenStream {
75 let TraitItemFn {
76 attrs,
77 sig,
78 default: _,
79 semi_token: _,
80 } = method;
81
82 let name = &sig.ident;
83
84 if is_valid_self(sig.inputs.first()).not() {
85 return Error::new_spanned(
86 method,
87 "Only methods with `self`, `&self` or `&mut self` are supported",
88 )
89 .to_compile_error();
90 }
91
92 let suffix = match sig.asyncness.is_some() {
93 false => quote! {},
94 true => quote! { .await },
95 };
96
97 if let ReturnType::Type(_, ty) = &sig.output
98 && let Type::ImplTrait(impl_trait) = ty.as_ref()
99 {
100 return Error::new_spanned(impl_trait, "Return impl is not supported").to_compile_error();
101 }
102
103 let remaining_inputs = sig.inputs.iter().skip(1).map(|arg| match arg {
104 FnArg::Receiver(rec) => {
105 Error::new_spanned(rec, "Self only as first argument please").to_compile_error()
106 }
107 FnArg::Typed(typed) => {
108 let name = typed.pat.as_ref();
109 quote! { , #name }
110 }
111 });
112
113 let generics = generics_for_method(&sig.generics);
114
115 let trait_type = match long_form {
116 false => quote! { #trait_name #trait_generic },
117 true => quote! { $trait_type },
118 };
119
120 quote! {
121 #(#attrs)* #sig {
122 match self {
123 $(
124 Self::$variant_name(__static_dispatch_value) => <$variant_type as #trait_type>::#name #generics(
125 __static_dispatch_value
126 #(#remaining_inputs)*
127 )#suffix,
128 )*
129 }
130 }
131 }
132}
133
134fn macro_name(ident: &Ident) -> Ident {
135 format_ident!("{}_static_dispatch_macro", ident)
136}
137
138fn dispatch_trait(attr: TokenStream, input: ItemTrait) -> proc_macro2::TokenStream {
139 let export = if attr.is_empty() {
140 false
141 } else {
142 let ident = match syn::parse::<Ident>(attr) {
143 Ok(ident) => ident,
144 Err(err) => return err.to_compile_error(),
145 };
146 if ident != "macro_export" {
147 return Error::new_spanned(&ident, "Only \"macro_export\" is allowed as attribute.")
148 .to_compile_error();
149 }
150 true
151 };
152
153 let trait_name = &input.ident;
154 let macro_name = macro_name(trait_name);
155 let (impl_generics, ty_generics, where_clause) = &input.generics.split_for_impl();
156
157 let short_items = input.items.iter().map(|item| match item {
158 TraitItem::Fn(method) => create_trait_item_macro(trait_name, ty_generics, method, false),
159 item => Error::new_spanned(item, "Only methods are supported").to_compile_error(),
160 });
161
162 let long_items = input.items.iter().map(|item| match item {
163 TraitItem::Fn(method) => create_trait_item_macro(trait_name, ty_generics, method, true),
164 item => Error::new_spanned(item, "Only methods are supported").to_compile_error(),
165 });
166
167 let export_prefix = match export {
168 false => quote! {},
169 true => quote! { #[macro_export] },
170 };
171
172 let visibility = &input.vis;
173 let use_statement = match export {
174 false => quote! { #visibility use #macro_name; },
175 true => quote! {},
176 };
177
178 quote! {
179 #export_prefix
181 macro_rules! #macro_name {
182 (
183 short
184 $vis:vis enum $name:ident {
185 $($variant_name:ident($variant_type:ty),)*
186 }
187 ) => {
188 impl #impl_generics #trait_name #ty_generics for $name #where_clause {
189 #(#short_items)*
190 }
191 };
192 (
193 long
194 $trait_type:ty
195 {
196 $($variant_name:ident($variant_type:ty),)*
197 }
198 $($rem:tt)*
199 ) => {
200 $($rem)* {
201 #(#long_items)*
202 }
203 };
204 }
205 #use_statement
206 }
207}
208
209fn edit_trait_path(trait_path: &mut Path) -> Result<(), proc_macro2::TokenStream> {
210 match trait_path.segments.last_mut() {
211 Some(segment) => {
212 segment.ident = macro_name(&segment.ident);
213 segment.arguments = PathArguments::None;
214 Ok(())
215 }
216 None => Err(
217 Error::new_spanned(trait_path, "Name or Path of the trait required").to_compile_error(),
218 ),
219 }
220}
221
222struct LongImpl {
223 _impl: Token![impl],
224 generics: Generics,
225 trait_: Path,
226 _for: Token![for],
227 self_ty: Type,
228}
229
230impl Parse for LongImpl {
231 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
232 Ok(Self {
233 _impl: input.parse()?,
234 generics: input.parse()?,
235 trait_: input.parse()?,
236 _for: input.parse()?,
237 self_ty: input.parse()?,
238 })
239 }
240}
241
242fn dispatch_enum(attr: TokenStream, input: ItemEnum) -> proc_macro2::TokenStream {
243 let enum_name = &input.ident;
244 let vis = &input.vis;
245 let variants = input.variants.iter();
246
247 let attr = proc_macro2::TokenStream::from(attr);
248
249 if let Ok(mut trait_path) = syn::parse2::<Path>(attr.clone()) {
250 if let Err(err) = edit_trait_path(&mut trait_path) {
251 return err;
252 }
253 return quote! {
254 #trait_path! {
255 short
256 #vis enum #enum_name {
257 #(#variants,)*
258 }
259 }
260 };
261 }
262
263 let item_impl = match syn::parse2::<LongImpl>(attr) {
264 Ok(item_impl) => item_impl,
265 Err(err) => return err.into_compile_error(),
266 };
267
268 let mut trait_path = item_impl.trait_.clone();
269 if let Err(err) = edit_trait_path(&mut trait_path) {
270 return err;
271 }
272
273 let (impl_generics, _ty_generics, where_clause) = item_impl.generics.split_for_impl();
274 let trait_name = &item_impl.trait_;
275 let name = item_impl.self_ty;
276
277 quote! {
278
279 #trait_path! {
280 long
281 #trait_name
282 {
283 #(#variants,)*
284 }
285 impl #impl_generics #trait_name for #name #where_clause
286 }
287 }
288}