1use darling::ast::NestedMeta;
2use darling::FromMeta;
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::ItemFn;
6
7#[derive(FromMeta, PartialEq, Eq, Debug)]
8enum TokioFlavor {
9 MultiThread,
10 CurrentThread,
11}
12
13#[derive(Debug, FromMeta)]
14struct MainArgs {
15 shimkit: Option<syn::Path>,
16 tokio: Option<syn::Path>,
17 flavor: Option<TokioFlavor>,
18 worker_threads: Option<u32>,
19 start_paused: Option<bool>,
20}
21
22#[proc_macro_attribute]
23pub fn main(args: TokenStream, input: TokenStream) -> TokenStream {
24 main_impl(args, input).unwrap_or_else(|err| err.into_compile_error().into())
25}
26
27fn main_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
28 let args = NestedMeta::parse_meta_list(args.into())?;
29 let args = MainArgs::from_list(&args)?;
30 let input: ItemFn = syn::parse(input)?;
31 let ident = input.sig.ident.clone();
32
33 let shimkit_path = args.shimkit.unwrap_or(syn::Path::from_string("shimkit")?);
34 let tokio_path = args.tokio.unwrap_or(syn::Path::from_string("tokio")?);
35
36 let flavor = match args.flavor.unwrap_or(TokioFlavor::CurrentThread) {
37 TokioFlavor::CurrentThread => "new_current_thread",
38 TokioFlavor::MultiThread => "new_multi_thread",
39 };
40
41 let flavor = syn::Ident::from_string(flavor)?;
42
43 let start_paused = match args.start_paused {
44 Some(true) => quote! { .start_paused(true) },
45 _ => quote! {},
46 };
47
48 let worker_threads = match args.worker_threads {
49 Some(n) => quote! { .worker_threads(#n) },
50 _ => quote! {},
51 };
52
53 let tokens = if input.sig.asyncness.is_none() {
54 quote! {
55 fn main() -> impl ::std::process::Termination {
56 #input
57 #shimkit_path::run::run(#ident)
58 }
59 }
60 } else {
61 quote! {
62 fn main() -> impl ::std::process::Termination {
63 fn inner_main(cmd: #shimkit_path::args::Arguments) -> impl ::std::process::Termination {
64 #input
65 #tokio_path::runtime::Builder::#flavor()
66 #worker_threads
67 .enable_all()
68 #start_paused
69 .build()
70 .unwrap()
71 .block_on(#ident(cmd))
72 }
73 #shimkit_path::run::run(inner_main)
74 }
75 }
76 };
77
78 Ok(tokens.into())
79}