1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span as Span2, TokenStream as TokenStream2};
5use quote::{quote, ToTokens};
6use syn::{
7 parenthesized,
8 parse::{Parse, ParseStream},
9 parse_macro_input,
10 punctuated::Punctuated,
11 Expr, GenericParam, Ident, Result, Token, Type,
12};
13
14#[derive(Debug, Eq, PartialEq, Clone)]
18enum FnArgName {
19 Ident(Ident),
20 Underscore(Token![_]),
21}
22
23impl Parse for FnArgName {
24 fn parse(input: ParseStream) -> Result<Self> {
25 if input.peek(Ident) {
26 Ok(Self::Ident(input.parse()?))
27 } else if input.peek(Token![_]) {
28 Ok(Self::Underscore(input.parse()?))
29 } else {
30 Err(input.error("expected identifier or underscore"))
31 }
32 }
33}
34
35impl ToTokens for FnArgName {
36 fn to_tokens(&self, tokens: &mut TokenStream2) {
37 match self {
38 Self::Ident(ident) => ident.to_tokens(tokens),
39 Self::Underscore(underscore) => underscore.to_tokens(tokens),
40 }
41 }
42}
43
44#[derive(Debug, Eq, PartialEq, Clone)]
46struct FnArg {
47 r#mut: Option<Token![mut]>,
48 name: FnArgName,
49 ty: Type,
50}
51
52impl Parse for FnArg {
53 fn parse(input: ParseStream) -> Result<Self> {
54 let r#mut = input.parse()?;
55 let name = input.parse()?;
56 let _ = input.parse::<Token![:]>()?;
57 let ty = input.parse()?;
58 Ok(Self { r#mut, name, ty })
59 }
60}
61
62impl ToTokens for FnArg {
63 fn to_tokens(&self, tokens: &mut TokenStream2) {
64 self.name.to_tokens(tokens);
65 Token).to_tokens(tokens);
66 self.ty.to_tokens(tokens);
67 }
68}
69
70#[derive(Debug, Eq, PartialEq)]
72struct DispatchArmExpr {
73 default: Option<Token![default]>,
74 generic_params: Option<Punctuated<GenericParam, Token![,]>>,
75 input_expr: FnArg,
76 extra_args: Vec<FnArg>,
77 body: Expr,
78}
79
80impl Parse for DispatchArmExpr {
81 fn parse(input: ParseStream) -> Result<Self> {
82 let default = input.parse::<Option<Token![default]>>()?;
83 let _ = input.parse::<Token![fn]>()?;
84 let generic_params = if input.peek(Token![<]) {
85 let _ = input.parse::<Token![<]>()?;
86 let generic_params =
87 Punctuated::<GenericParam, Token![,]>::parse_separated_nonempty(input)?;
88 let _ = input.parse::<Token![>]>()?;
89 Some(generic_params)
90 } else {
91 None
92 };
93 let input_expr_content;
94 let _ = parenthesized!(input_expr_content in input);
95 let input_expr = input_expr_content.parse()?;
96 let extra_args = if input_expr_content.peek(Token![,]) {
97 let _ = input_expr_content.parse::<Token![,]>()?;
98 Punctuated::<FnArg, Token![,]>::parse_separated_nonempty(&input_expr_content)?
99 .into_iter()
100 .collect()
101 } else {
102 Vec::new()
103 };
104 let _ = input.parse::<Token![=>]>()?;
105 let body = input.parse()?;
106 Ok(Self {
107 default,
108 generic_params,
109 input_expr,
110 extra_args,
111 body,
112 })
113 }
114}
115
116#[derive(Debug, Eq, PartialEq)]
119struct SpecializedDispatchExpr {
120 from_type: Type,
121 to_type: Type,
122 arms: Vec<DispatchArmExpr>,
123 input_expr: Expr,
124 extra_args: Vec<Expr>,
125}
126
127fn parse_punctuated_arms(input: &ParseStream) -> Result<Punctuated<DispatchArmExpr, Token![,]>> {
129 let mut arms = Punctuated::new();
130 loop {
131 if input.peek(Token![default]) || input.peek(Token![fn]) {
132 arms.push(input.parse()?);
133 } else {
134 break;
135 }
136 if input.peek(Token![,]) && (input.peek2(Token![default]) || input.peek2(Token![fn])) {
137 let _ = input.parse::<Token![,]>()?;
138 } else {
139 break;
140 }
141 }
142 Ok(arms)
143}
144
145impl Parse for SpecializedDispatchExpr {
146 fn parse(input: ParseStream) -> Result<Self> {
147 let from_type = input.parse()?;
148 let _ = input.parse::<Token![->]>()?;
149 let to_type = input.parse()?;
150 let _ = input.parse::<Token![,]>()?;
151 let arms = parse_punctuated_arms(&input)?.into_iter().collect();
152 let _ = input.parse::<Token![,]>()?;
153 let input_expr = input.parse()?;
154 let _ = input.parse::<Token![,]>().ok();
155 let extra_args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?
156 .into_iter()
157 .collect();
158 Ok(Self {
159 from_type,
160 to_type,
161 arms,
162 input_expr,
163 extra_args,
164 })
165 }
166}
167
168fn generate_trait_declaration(
170 trait_name: &Ident,
171 extra_args: &[FnArg],
172 return_type: &Type,
173) -> TokenStream2 {
174 let tpl = Ident::new("T", Span2::mixed_site());
176 quote! {
177 trait #trait_name<#tpl> {
178 fn dispatch(_: #tpl #(, #extra_args)*) -> #return_type;
179 }
180 }
181}
182
183fn generate_trait_implementation(
186 default: Option<&Token![default]>,
187 trait_name: &Ident,
188 generic_params: Option<&Punctuated<GenericParam, Token![,]>>,
189 FnArg {
190 r#mut: input_expr_mut,
191 name: input_expr_name,
192 ty: input_expr_type,
193 }: &FnArg,
194 extra_args: &[FnArg],
195 return_type: &Type,
196 body: &Expr,
197) -> TokenStream2 {
198 let generics = generic_params.map(|g| quote! {<#g>});
199 quote! {
200 impl #generics #trait_name<#input_expr_type> for #input_expr_type {
201 #default fn dispatch(#input_expr_mut #input_expr_name: #input_expr_type #(, #extra_args)*) -> #return_type {
202 #body
203 }
204 }
205 }
206}
207
208fn generate_dispatch_call(
210 from_type: &Type,
211 trait_name: &Ident,
212 input_expr: &Expr,
213 extra_args: &[Expr],
214) -> TokenStream2 {
215 quote! {
216 <#from_type as #trait_name<#from_type>>::dispatch(#input_expr #(, #extra_args)*)
217 }
218}
219
220impl ToTokens for SpecializedDispatchExpr {
221 fn to_tokens(&self, tokens: &mut TokenStream2) {
222 let trait_name = Ident::new("SpecializedDispatchCall", Span2::mixed_site());
223 let mut trait_impls = TokenStream2::new();
224 let mut extra_args = None;
225
226 for arm in &self.arms {
227 if arm.default.is_some() && extra_args.is_none() {
228 extra_args = Some(&arm.extra_args);
229 }
230 trait_impls.extend(generate_trait_implementation(
231 arm.default.as_ref(),
232 &trait_name,
233 arm.generic_params.as_ref(),
234 &arm.input_expr,
235 &arm.extra_args,
236 &self.to_type,
237 &arm.body,
238 ));
239 }
240
241 let trait_decl = generate_trait_declaration(
242 &trait_name,
243 extra_args.unwrap_or(&Vec::new()),
244 &self.to_type,
245 );
246
247 let dispatch_call = generate_dispatch_call(
248 &self.from_type,
249 &trait_name,
250 &self.input_expr,
251 &self.extra_args,
252 );
253
254 tokens.extend(quote! {
255 {
256 #trait_decl
257 #trait_impls
258 #dispatch_call
259 }
260 });
261 }
262}
263
264#[proc_macro]
267pub fn specialized_dispatch(input: TokenStream) -> TokenStream {
268 parse_macro_input!(input as SpecializedDispatchExpr)
269 .into_token_stream()
270 .into()
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Expected parse result for a plain `fn (<input_expr>, <extra_args>) => <body>` arm.
    fn concrete_arm(input_expr: FnArg, extra_args: Vec<FnArg>, body: Expr) -> DispatchArmExpr {
        DispatchArmExpr {
            default: None,
            generic_params: None,
            input_expr,
            extra_args,
            body,
        }
    }

    /// Expected parse result for a `default fn <T>(_: T, <extra_args>) => <body>` arm.
    fn default_generic_arm(extra_args: Vec<FnArg>, body: Expr) -> DispatchArmExpr {
        DispatchArmExpr {
            default: Some(Default::default()),
            generic_params: Some(parse_quote!(T)),
            input_expr: parse_quote!(_: T),
            extra_args,
            body,
        }
    }

    /// The `arg1: u8, arg2: u16, arg3: &str` parameters shared by every arm in
    /// `parse_trailing_args`.
    fn trailing_params() -> Vec<FnArg> {
        vec![
            parse_quote!(arg1: u8),
            parse_quote!(arg2: u16),
            parse_quote!(arg3: &str),
        ]
    }

    #[test]
    fn parse_arm_with_concrete_type() {
        let parsed: DispatchArmExpr = parse_quote!(fn (v: u8) => format!("u8: {}", v));
        let expected = concrete_arm(
            parse_quote!(v: u8),
            Vec::new(),
            parse_quote!(format!("u8: {}", v)),
        );
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_arm_with_generic_type() {
        let parsed: DispatchArmExpr =
            parse_quote!(default fn <T>(_: T) => format!("default value"));
        let expected = default_generic_arm(Vec::new(), parse_quote!(format!("default value")));
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_specialized_dispatch_expr() {
        let parsed: SpecializedDispatchExpr = parse_quote! {
            E -> String,
            default fn <T>(_: T) => format!("default value"),
            fn (v: u8) => format!("u8: {}", v),
            fn (v: u16) => format!("u16: {}", v),
            expr,
        };
        let expected = SpecializedDispatchExpr {
            from_type: parse_quote!(E),
            to_type: parse_quote!(String),
            arms: vec![
                default_generic_arm(Vec::new(), parse_quote!(format!("default value"))),
                concrete_arm(
                    parse_quote!(v: u8),
                    Vec::new(),
                    parse_quote!(format!("u8: {}", v)),
                ),
                concrete_arm(
                    parse_quote!(v: u16),
                    Vec::new(),
                    parse_quote!(format!("u16: {}", v)),
                ),
            ],
            input_expr: parse_quote!(expr),
            extra_args: Vec::new(),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_trailing_args() {
        let parsed: SpecializedDispatchExpr = parse_quote! {
            E -> String,
            default fn <T>(_: T, arg1: u8, arg2: u16, arg3: &str) => format!("default value"),
            fn (v: u8, arg1: u8, arg2: u16, arg3: &str) => format!("u8: {}", v),
            fn (v: u16, arg1: u8, arg2: u16, arg3: &str) => format!("u16: {}", v),
            expr,
            1u8,
            2u16,
            "bugun_bayram_erken_kalkin_cocuklar",
        };
        let expected = SpecializedDispatchExpr {
            from_type: parse_quote!(E),
            to_type: parse_quote!(String),
            arms: vec![
                default_generic_arm(trailing_params(), parse_quote!(format!("default value"))),
                concrete_arm(
                    parse_quote!(v: u8),
                    trailing_params(),
                    parse_quote!(format!("u8: {}", v)),
                ),
                concrete_arm(
                    parse_quote!(v: u16),
                    trailing_params(),
                    parse_quote!(format!("u16: {}", v)),
                ),
            ],
            input_expr: parse_quote!(expr),
            extra_args: vec![
                parse_quote!(1u8),
                parse_quote!(2u16),
                parse_quote!("bugun_bayram_erken_kalkin_cocuklar"),
            ],
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_mut_arg() {
        let parsed: FnArg = parse_quote!(mut v: u8);
        let expected = FnArg {
            r#mut: Some(parse_quote!(mut)),
            name: FnArgName::Ident(parse_quote!(v)),
            ty: parse_quote!(u8),
        };
        assert_eq!(parsed, expected);
    }
}