use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{ToTokens, quote};
use syn::{
    Error, Fields, FnArg, GenericParam, Ident, Item, ItemEnum, ItemTrait, Pat, PatIdent, Path,
    Token, TraitItem, Type, WhereClause, parse::Parse, parse_macro_input,
};
8
9#[proc_macro_attribute]
13pub fn setup(_attr: TokenStream, item: TokenStream) -> TokenStream {
14 let input = parse_macro_input!(item as Item);
16 let name = match &input {
17 Item::Trait(value) => &value.ident,
18 Item::Enum(value) => &value.ident,
19 _ => {
20 return Error::new_spanned(&input, "dispatch is only valid on traits or enums")
21 .to_compile_error()
22 .into();
23 }
24 };
25
26 let save = macro_data::save(name, &input);
27
28 quote! {
29 #input
30 #save
31 }
32 .into()
33}
34
35struct GenerateInput {
36 trait_name: Path,
37 _for: Token![for],
38 enum_name: Path,
39}
40
41impl Parse for GenerateInput {
42 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
43 Ok(Self {
44 trait_name: input.parse()?,
45 _for: input.parse()?,
46 enum_name: input.parse()?,
47 })
48 }
49}
50
51#[proc_macro]
59pub fn implementation(input: TokenStream) -> TokenStream {
60 let input = parse_macro_input!(input as GenerateInput);
61
62 let data = FinalTransfer {
63 trait_item: macro_data::request(&input.trait_name),
64 comma: syn::token::Comma(Span::mixed_site()),
65 enum_item: macro_data::request(&input.enum_name),
66 };
67
68 macro_data::transfer("static_dispatch", "generate_final", &data).into()
69}
70
71struct FinalTransfer<S: macro_data::Storage> {
72 trait_item: macro_data::Transfer<ItemTrait, S>,
73 comma: Token![,],
74 enum_item: macro_data::Transfer<ItemEnum, S>,
75}
76
77impl ToTokens for FinalTransfer<macro_data::Request> {
78 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
79 self.trait_item.to_tokens(tokens);
80 self.comma.to_tokens(tokens);
81 self.enum_item.to_tokens(tokens);
82 }
83}
84
85impl Parse for FinalTransfer<macro_data::Load> {
86 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
87 Ok(Self {
88 trait_item: input.parse()?,
89 comma: input.parse()?,
90 enum_item: input.parse()?,
91 })
92 }
93}
94
95#[doc(hidden)]
99#[proc_macro]
100pub fn generate_final(input: TokenStream) -> TokenStream {
101 let input = parse_macro_input!(input as FinalTransfer<macro_data::Load>);
102 let trait_item = input.trait_item.0;
103 let enum_item = input.enum_item.0;
104
105 let trait_ident = &trait_item.ident;
106 let enum_ident = &enum_item.ident;
107
108 let mut all_params = Vec::new();
110 for param in &trait_item.generics.params {
111 all_params.push(param.clone());
112 }
113 for param in &enum_item.generics.params {
114 all_params.push(param.clone());
115 }
116 all_params.sort_by_key(|param| match param {
117 GenericParam::Lifetime(_) => 0,
118 GenericParam::Const(_) => 1,
119 GenericParam::Type(_) => 2,
120 });
121
122 let impl_generics = if all_params.is_empty() {
123 quote! {}
124 } else {
125 quote! { < #(#all_params),* > }
126 };
127
128 let mut where_predicates = Vec::new();
130 if let Some(wc) = &trait_item.generics.where_clause {
131 where_predicates.extend(wc.predicates.iter().cloned());
132 }
133 if let Some(wc) = &enum_item.generics.where_clause {
134 where_predicates.extend(wc.predicates.iter().cloned());
135 }
136 all_params.sort_by_key(|param| match param {
137 GenericParam::Lifetime(_) => 0,
138 GenericParam::Const(_) => 1,
139 GenericParam::Type(_) => 2,
140 });
141
142 let where_clause = if where_predicates.is_empty() {
143 None
144 } else {
145 Some(WhereClause {
146 where_token: syn::token::Where::default(),
147 predicates: syn::punctuated::Punctuated::from_iter(where_predicates),
148 })
149 };
150 let trait_args = generic_args(&trait_item.generics);
151 let enum_args = generic_args(&enum_item.generics);
152
153 let impl_methods = trait_item
155 .items
156 .iter()
157 .map(|item| {
158 let TraitItem::Fn(method) = item else {
159 return Error::new_spanned(item, "Only methods are supported").to_compile_error();
160 };
161 let sig = &method.sig;
162 let method_name = &sig.ident;
163 let method_gen = sig
164 .generics
165 .params
166 .iter()
167 .filter_map(|param| match param {
168 GenericParam::Lifetime(_) => None,
169 GenericParam::Const(param) => Some(¶m.ident),
170 GenericParam::Type(param) => Some(¶m.ident),
171 })
172 .collect::<Vec<_>>();
173
174 let mut args = sig.inputs.iter();
175 let self_arg = match args.next() {
176 Some(FnArg::Receiver(rec)) => &rec.self_token,
177 _ => {
178 return Error::new_spanned(sig, "Function requires self argument")
179 .to_compile_error();
180 }
181 };
182
183 let args = sig
184 .inputs
185 .iter()
186 .skip(1)
187 .map(|arg| {
188 if let syn::FnArg::Typed(pat_type) = arg {
189 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
190 pat_ident.ident.clone()
191 } else {
192 panic!("Unsupported argument pattern");
193 }
194 } else {
195 panic!("Expected typed argument");
196 }
197 })
198 .collect::<Vec<_>>();
199
200 let async_suffix = match sig.asyncness {
201 None => quote! {},
202 Some(_) => quote! {.await},
203 };
204
205 let arms = enum_item
207 .variants
208 .iter()
209 .map(|variant| {
210 let variant_ident = &variant.ident;
211 let Fields::Unnamed(fields) = &variant.fields else {
212 panic!("Only enum tuples supported");
213 };
214 let field = fields.unnamed.iter().next().expect("expected a field");
215 let field_type = &field.ty;
216 let method_gen = quote! { ::<#(#method_gen,)*> };
217 quote! {
218 #enum_ident::#variant_ident(__static_dispatch_value) =>
219 <#field_type as #trait_ident #trait_args>::#method_name #method_gen(
220 __static_dispatch_value,
221 #(#args),*
222 ) #async_suffix
223 }
224 })
225 .collect::<Vec<_>>();
226
227 quote! {
228 #sig {
229 match #self_arg {
230 #(#arms,)*
231 }
232 }
233 }
234 })
235 .collect::<Vec<_>>();
236
237 let expanded = quote! {
238 impl #impl_generics #trait_ident #trait_args for #enum_ident #enum_args #where_clause {
239 #(#impl_methods)*
240 }
241 };
242
243 expanded.into()
244}
245
246fn generic_args(generics: &syn::Generics) -> proc_macro2::TokenStream {
247 let args: Vec<_> = generics
248 .params
249 .iter()
250 .map(|param| match param {
251 GenericParam::Type(ty) => {
252 let ident = &ty.ident;
253 quote! { #ident }
254 }
255 GenericParam::Lifetime(lifetime) => {
256 let lt = &lifetime.lifetime;
257 quote! { #lt }
258 }
259 GenericParam::Const(c) => {
260 let ident = &c.ident;
261 quote! { #ident }
262 }
263 })
264 .collect();
265 if args.is_empty() {
266 quote! {}
267 } else {
268 quote! { < #(#args),* > }
269 }
270}