specialized_dispatch/
lib.rs

1#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span as Span2, TokenStream as TokenStream2};
5use quote::{quote, ToTokens};
6use syn::{
7    parenthesized,
8    parse::{Parse, ParseStream},
9    parse_macro_input,
10    punctuated::Punctuated,
11    Expr, GenericParam, Ident, Result, Token, Type,
12};
13
14/// Parses either an identifier or an underscore for arguments of specializations.
15// TODO(ozars): Make this accept patterns for unpacking arguments. Maybe switch to using
16// `syn::PatType`.
17#[derive(Debug, Eq, PartialEq, Clone)]
18enum FnArgName {
19    Ident(Ident),
20    Underscore(Token![_]),
21}
22
23impl Parse for FnArgName {
24    fn parse(input: ParseStream) -> Result<Self> {
25        if input.peek(Ident) {
26            Ok(Self::Ident(input.parse()?))
27        } else if input.peek(Token![_]) {
28            Ok(Self::Underscore(input.parse()?))
29        } else {
30            Err(input.error("expected identifier or underscore"))
31        }
32    }
33}
34
35impl ToTokens for FnArgName {
36    fn to_tokens(&self, tokens: &mut TokenStream2) {
37        match self {
38            Self::Ident(ident) => ident.to_tokens(tokens),
39            Self::Underscore(underscore) => underscore.to_tokens(tokens),
40        }
41    }
42}
43
44/// Function argument with name and type.
45#[derive(Debug, Eq, PartialEq, Clone)]
46struct FnArg {
47    r#mut: Option<Token![mut]>,
48    name: FnArgName,
49    ty: Type,
50}
51
52impl Parse for FnArg {
53    fn parse(input: ParseStream) -> Result<Self> {
54        let r#mut = input.parse()?;
55        let name = input.parse()?;
56        let _ = input.parse::<Token![:]>()?;
57        let ty = input.parse()?;
58        Ok(Self { r#mut, name, ty })
59    }
60}
61
62impl ToTokens for FnArg {
63    fn to_tokens(&self, tokens: &mut TokenStream2) {
64        self.name.to_tokens(tokens);
65        Token![:](Span2::mixed_site()).to_tokens(tokens);
66        self.ty.to_tokens(tokens);
67    }
68}
69
70/// Represents an arm for specialized dispatch macro.
71#[derive(Debug, Eq, PartialEq)]
72struct DispatchArmExpr {
73    default: Option<Token![default]>,
74    generic_params: Option<Punctuated<GenericParam, Token![,]>>,
75    input_expr: FnArg,
76    extra_args: Vec<FnArg>,
77    body: Expr,
78}
79
80impl Parse for DispatchArmExpr {
81    fn parse(input: ParseStream) -> Result<Self> {
82        let default = input.parse::<Option<Token![default]>>()?;
83        let _ = input.parse::<Token![fn]>()?;
84        let generic_params = if input.peek(Token![<]) {
85            let _ = input.parse::<Token![<]>()?;
86            let generic_params =
87                Punctuated::<GenericParam, Token![,]>::parse_separated_nonempty(input)?;
88            let _ = input.parse::<Token![>]>()?;
89            Some(generic_params)
90        } else {
91            None
92        };
93        let input_expr_content;
94        let _ = parenthesized!(input_expr_content in input);
95        let input_expr = input_expr_content.parse()?;
96        let extra_args = if input_expr_content.peek(Token![,]) {
97            let _ = input_expr_content.parse::<Token![,]>()?;
98            Punctuated::<FnArg, Token![,]>::parse_separated_nonempty(&input_expr_content)?
99                .into_iter()
100                .collect()
101        } else {
102            Vec::new()
103        };
104        let _ = input.parse::<Token![=>]>()?;
105        let body = input.parse()?;
106        Ok(Self {
107            default,
108            generic_params,
109            input_expr,
110            extra_args,
111            body,
112        })
113    }
114}
115
116/// This is entry point for handling arguments of `specialized_dispatch` macro. It parses arguments
117/// of the specialized dispatch macro and expands to the corresponding implementation.
118#[derive(Debug, Eq, PartialEq)]
119struct SpecializedDispatchExpr {
120    from_type: Type,
121    to_type: Type,
122    arms: Vec<DispatchArmExpr>,
123    input_expr: Expr,
124    extra_args: Vec<Expr>,
125}
126
127/// Parses specialization arms as long as they start with `default` or `fn`.
128fn parse_punctuated_arms(input: &ParseStream) -> Result<Punctuated<DispatchArmExpr, Token![,]>> {
129    let mut arms = Punctuated::new();
130    loop {
131        if input.peek(Token![default]) || input.peek(Token![fn]) {
132            arms.push(input.parse()?);
133        } else {
134            break;
135        }
136        if input.peek(Token![,]) && (input.peek2(Token![default]) || input.peek2(Token![fn])) {
137            let _ = input.parse::<Token![,]>()?;
138        } else {
139            break;
140        }
141    }
142    Ok(arms)
143}
144
145impl Parse for SpecializedDispatchExpr {
146    fn parse(input: ParseStream) -> Result<Self> {
147        let from_type = input.parse()?;
148        let _ = input.parse::<Token![->]>()?;
149        let to_type = input.parse()?;
150        let _ = input.parse::<Token![,]>()?;
151        let arms = parse_punctuated_arms(&input)?.into_iter().collect();
152        let _ = input.parse::<Token![,]>()?;
153        let input_expr = input.parse()?;
154        let _ = input.parse::<Token![,]>().ok();
155        let extra_args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?
156            .into_iter()
157            .collect();
158        Ok(Self {
159            from_type,
160            to_type,
161            arms,
162            input_expr,
163            extra_args,
164        })
165    }
166}
167
168/// Generates local helper trait declaration that will be used for specialized dispatch.
169fn generate_trait_declaration(
170    trait_name: &Ident,
171    extra_args: &[FnArg],
172    return_type: &Type,
173) -> TokenStream2 {
174    // TODO(ozars): Consider passing generic types from the default specialization as well.
175    let tpl = Ident::new("T", Span2::mixed_site());
176    quote! {
177        trait #trait_name<#tpl> {
178            fn dispatch(_: #tpl #(, #extra_args)*) -> #return_type;
179        }
180    }
181}
182
183/// Generates implementation of the helper trait for specialized dispatch arms. This covers both
184/// generic case(s) and concrete case(s).
185fn generate_trait_implementation(
186    default: Option<&Token![default]>,
187    trait_name: &Ident,
188    generic_params: Option<&Punctuated<GenericParam, Token![,]>>,
189    FnArg {
190        r#mut: input_expr_mut,
191        name: input_expr_name,
192        ty: input_expr_type,
193    }: &FnArg,
194    extra_args: &[FnArg],
195    return_type: &Type,
196    body: &Expr,
197) -> TokenStream2 {
198    let generics = generic_params.map(|g| quote! {<#g>});
199    quote! {
200        impl #generics #trait_name<#input_expr_type> for #input_expr_type {
201            #default fn dispatch(#input_expr_mut #input_expr_name: #input_expr_type #(, #extra_args)*) -> #return_type {
202                #body
203            }
204        }
205    }
206}
207
208/// Generates the dispatch call to the helper trait.
209fn generate_dispatch_call(
210    from_type: &Type,
211    trait_name: &Ident,
212    input_expr: &Expr,
213    extra_args: &[Expr],
214) -> TokenStream2 {
215    quote! {
216        <#from_type as #trait_name<#from_type>>::dispatch(#input_expr #(, #extra_args)*)
217    }
218}
219
220impl ToTokens for SpecializedDispatchExpr {
221    fn to_tokens(&self, tokens: &mut TokenStream2) {
222        let trait_name = Ident::new("SpecializedDispatchCall", Span2::mixed_site());
223        let mut trait_impls = TokenStream2::new();
224        let mut extra_args = None;
225
226        for arm in &self.arms {
227            if arm.default.is_some() && extra_args.is_none() {
228                extra_args = Some(&arm.extra_args);
229            }
230            trait_impls.extend(generate_trait_implementation(
231                arm.default.as_ref(),
232                &trait_name,
233                arm.generic_params.as_ref(),
234                &arm.input_expr,
235                &arm.extra_args,
236                &self.to_type,
237                &arm.body,
238            ));
239        }
240
241        let trait_decl = generate_trait_declaration(
242            &trait_name,
243            extra_args.unwrap_or(&Vec::new()),
244            &self.to_type,
245        );
246
247        let dispatch_call = generate_dispatch_call(
248            &self.from_type,
249            &trait_name,
250            &self.input_expr,
251            &self.extra_args,
252        );
253
254        tokens.extend(quote! {
255            {
256                #trait_decl
257                #trait_impls
258                #dispatch_call
259            }
260        });
261    }
262}
263
264/// Entry point for the macro. Please see [the crate documentation](`crate`) for
265/// more information and example.
266#[proc_macro]
267pub fn specialized_dispatch(input: TokenStream) -> TokenStream {
268    parse_macro_input!(input as SpecializedDispatchExpr)
269        .into_token_stream()
270        .into()
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Expected `default fn <T>(_: T, ...)` arm with the given extra arguments and body.
    fn default_generic_arm(extra_args: Vec<FnArg>, body: Expr) -> DispatchArmExpr {
        DispatchArmExpr {
            default: Some(Default::default()),
            generic_params: Some(parse_quote!(T)),
            input_expr: parse_quote!(_: T),
            extra_args,
            body,
        }
    }

    /// Expected non-`default`, non-generic arm.
    fn concrete_arm(input_expr: FnArg, extra_args: Vec<FnArg>, body: Expr) -> DispatchArmExpr {
        DispatchArmExpr {
            default: None,
            generic_params: None,
            input_expr,
            extra_args,
            body,
        }
    }

    /// The `arg1: u8, arg2: u16, arg3: &str` arguments shared by every arm in
    /// `parse_trailing_args`.
    fn trailing_fn_args() -> Vec<FnArg> {
        vec![
            parse_quote!(arg1: u8),
            parse_quote!(arg2: u16),
            parse_quote!(arg3: &str),
        ]
    }

    #[test]
    fn parse_arm_with_concrete_type() {
        let parsed: DispatchArmExpr = parse_quote!(fn (v: u8) => format!("u8: {}", v));
        let expected = concrete_arm(
            parse_quote!(v: u8),
            Vec::new(),
            parse_quote!(format!("u8: {}", v)),
        );
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_arm_with_generic_type() {
        let parsed: DispatchArmExpr =
            parse_quote!(default fn <T>(_: T) => format!("default value"));
        let expected = default_generic_arm(Vec::new(), parse_quote!(format!("default value")));
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_specialized_dispatch_expr() {
        let parsed: SpecializedDispatchExpr = parse_quote! {
            E -> String,
            default fn <T>(_: T) => format!("default value"),
            fn (v: u8) => format!("u8: {}", v),
            fn (v: u16) => format!("u16: {}", v),
            expr,
        };
        let expected = SpecializedDispatchExpr {
            from_type: parse_quote!(E),
            to_type: parse_quote!(String),
            arms: vec![
                default_generic_arm(Vec::new(), parse_quote!(format!("default value"))),
                concrete_arm(
                    parse_quote!(v: u8),
                    Vec::new(),
                    parse_quote!(format!("u8: {}", v)),
                ),
                concrete_arm(
                    parse_quote!(v: u16),
                    Vec::new(),
                    parse_quote!(format!("u16: {}", v)),
                ),
            ],
            input_expr: parse_quote!(expr),
            extra_args: Vec::new(),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_trailing_args() {
        let parsed: SpecializedDispatchExpr = parse_quote! {
            E -> String,
            default fn <T>(_: T, arg1: u8, arg2: u16, arg3: &str) => format!("default value"),
            fn (v: u8, arg1: u8, arg2: u16, arg3: &str) => format!("u8: {}", v),
            fn (v: u16, arg1: u8, arg2: u16, arg3: &str) => format!("u16: {}", v),
            expr,
            1u8,
            2u16,
            "bugun_bayram_erken_kalkin_cocuklar",
        };

        let expected = SpecializedDispatchExpr {
            from_type: parse_quote!(E),
            to_type: parse_quote!(String),
            arms: vec![
                default_generic_arm(trailing_fn_args(), parse_quote!(format!("default value"))),
                concrete_arm(
                    parse_quote!(v: u8),
                    trailing_fn_args(),
                    parse_quote!(format!("u8: {}", v)),
                ),
                concrete_arm(
                    parse_quote!(v: u16),
                    trailing_fn_args(),
                    parse_quote!(format!("u16: {}", v)),
                ),
            ],
            input_expr: parse_quote!(expr),
            extra_args: vec![
                parse_quote!(1u8),
                parse_quote!(2u16),
                parse_quote!("bugun_bayram_erken_kalkin_cocuklar"),
            ],
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_mut_arg() {
        let parsed: FnArg = parse_quote!(mut v: u8);
        let expected = FnArg {
            r#mut: Some(parse_quote!(mut)),
            name: FnArgName::Ident(parse_quote!(v)),
            ty: parse_quote!(u8),
        };
        assert_eq!(parsed, expected);
    }
}