1#![deny(missing_docs)]
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::parse::Parser;
9use syn::parse_macro_input;
10use syn::punctuated::Punctuated;
11use syn::spanned::Spanned;
12use syn::{Expr, ExprLit, FnArg, ItemFn, Lit, MetaNameValue, Pat, PatIdent, Token};
13
14#[derive(Default)]
15struct MainArgs {
16 shards: Option<Expr>,
17 backend: Option<BackendArg>,
18}
19
/// Backends that can be named in the `backend = "..."` attribute option.
///
/// A variant exists for every string accepted by `BackendArg::parse`.
#[derive(Copy, Clone)]
enum BackendArg {
    /// Selected by the literal `"io_uring"`.
    IoUring,
}
24
25impl BackendArg {
26 fn parse(value: &Expr) -> syn::Result<Self> {
27 let Expr::Lit(ExprLit {
28 lit: Lit::Str(lit), ..
29 }) = value
30 else {
31 return Err(syn::Error::new(
32 value.span(),
33 "backend must be a string literal: \"io_uring\"",
34 ));
35 };
36
37 match lit.value().as_str() {
38 "io_uring" => Ok(Self::IoUring),
39 other => Err(syn::Error::new(
40 lit.span(),
41 format!("unsupported backend '{other}'; expected \"io_uring\""),
42 )),
43 }
44 }
45
46 fn as_tokens(self) -> proc_macro2::TokenStream {
47 match self {
48 Self::IoUring => quote!(::spargio::BackendKind::IoUring),
49 }
50 }
51}
52
53impl MainArgs {
54 fn parse(args: TokenStream) -> syn::Result<Self> {
55 let mut out = Self::default();
56 let parser = Punctuated::<MetaNameValue, Token![,]>::parse_terminated;
57 let args = parser.parse(args)?;
58 for arg in args {
59 if arg.path.is_ident("shards") {
60 if out.shards.is_some() {
61 return Err(syn::Error::new(
62 arg.path.span(),
63 "duplicate 'shards' option",
64 ));
65 }
66 out.shards = Some(arg.value);
67 continue;
68 }
69 if arg.path.is_ident("backend") {
70 if out.backend.is_some() {
71 return Err(syn::Error::new(
72 arg.path.span(),
73 "duplicate 'backend' option",
74 ));
75 }
76 out.backend = Some(BackendArg::parse(&arg.value)?);
77 continue;
78 }
79 return Err(syn::Error::new(
80 arg.path.span(),
81 "unsupported option; expected one of: shards, backend",
82 ));
83 }
84 Ok(out)
85 }
86}
87
88#[proc_macro_attribute]
89pub fn main(args: TokenStream, item: TokenStream) -> TokenStream {
98 let args = match MainArgs::parse(args) {
99 Ok(args) => args,
100 Err(err) => return err.to_compile_error().into(),
101 };
102
103 let input = parse_macro_input!(item as ItemFn);
104 if input.sig.asyncness.is_none() {
105 return syn::Error::new(
106 input.sig.fn_token.span(),
107 "#[spargio::main] can only be used on async functions",
108 )
109 .to_compile_error()
110 .into();
111 }
112 let inject_handle = match input.sig.inputs.len() {
113 0 => None,
114 1 => {
115 let Some(arg) = input.sig.inputs.first() else {
116 return syn::Error::new(input.sig.inputs.span(), "missing function parameter")
117 .to_compile_error()
118 .into();
119 };
120 match arg {
121 FnArg::Typed(pat_type) => match pat_type.pat.as_ref() {
122 Pat::Ident(PatIdent { .. }) => Some(()),
123 _ => {
124 return syn::Error::new(
125 pat_type.pat.span(),
126 "#[spargio::main] parameter must be an identifier binding",
127 )
128 .to_compile_error()
129 .into();
130 }
131 },
132 FnArg::Receiver(receiver) => {
133 return syn::Error::new(
134 receiver.span(),
135 "#[spargio::main] does not support method receivers",
136 )
137 .to_compile_error()
138 .into();
139 }
140 }
141 }
142 _ => {
143 return syn::Error::new(
144 input.sig.inputs.span(),
145 "#[spargio::main] supports at most one function parameter (RuntimeHandle)",
146 )
147 .to_compile_error()
148 .into();
149 }
150 };
151 if !input.sig.generics.params.is_empty() {
152 return syn::Error::new(
153 input.sig.generics.span(),
154 "#[spargio::main] does not support generic parameters",
155 )
156 .to_compile_error()
157 .into();
158 }
159
160 let attrs = input.attrs;
161 let vis = input.vis;
162 let name = input.sig.ident;
163 let inputs = input.sig.inputs;
164 let output = input.sig.output;
165 let block = input.block;
166 let inner_name = syn::Ident::new(&format!("__spargio_async_{}", name), name.span());
167
168 let shards_builder = args
169 .shards
170 .map(|expr| quote!(.shards(#expr)))
171 .unwrap_or_default();
172 let backend_builder = args
173 .backend
174 .map(|backend| {
175 let backend = backend.as_tokens();
176 quote!(.backend(#backend))
177 })
178 .unwrap_or_default();
179
180 let call_inner = if inject_handle.is_some() {
181 quote!(#inner_name(__spargio_handle).await)
182 } else {
183 quote!(#inner_name().await)
184 };
185
186 quote! {
187 #(#attrs)*
188 #vis fn #name() #output {
189 let __spargio_builder = ::spargio::Runtime::builder()
190 #shards_builder
191 #backend_builder;
192 match ::spargio::__private::block_on(::spargio::run_with(__spargio_builder, |__spargio_handle| async move { #call_inner })) {
193 Ok(__spargio_out) => __spargio_out,
194 Err(::spargio::RuntimeError::UnsupportedBackend(__spargio_msg)) => {
195 panic!(
196 "spargio::main backend is not supported on this platform: {}",
197 __spargio_msg
198 )
199 }
200 Err(__spargio_err) => {
201 panic!("spargio::main runtime startup failed: {:?}", __spargio_err)
202 }
203 }
204 }
205
206 async fn #inner_name(#inputs) #output {
207 #block
208 }
209 }
210 .into()
211}