1#![feature(iter_array_chunks)]
2#![feature(let_chains)]
3
4use convert_case::{Case, Casing};
5use proc_macro::TokenStream as TokenStreamV1;
6use proc_macro2::{Delimiter, Ident, TokenStream as TokenStreamV2, TokenTree};
7use quote::{format_ident, quote, TokenStreamExt};
8
/// Selects which wrapper ("thunk") `generate_thunk` emits for a server function.
// `MessagePack` is only constructed when the `messagepack` feature is enabled,
// so without that feature the variant would trigger the dead-code lint.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug)]
enum ThunkType {
    /// `{name}_thunk`, which receives the arguments struct by value.
    Default,
    /// `{name}_messagepack_thunk`, which decodes the arguments struct from a
    /// MessagePack byte slice.
    MessagePack,
}
15
16#[derive(Debug)]
17struct FnData {
18 is_async: bool,
19 name: Ident,
20 return_type: Option<Ident>,
21}
22impl FnData {
23 fn get_fn_name(token_stream: TokenStreamV2) -> Result<(Ident, bool), ()> {
24 let mut tokens_iter = token_stream.into_iter();
25
26 let mut is_next_token_fn_name = false;
27 let mut is_async = false;
28 let fn_name = tokens_iter
29 .find(|token_tree| {
30 if is_next_token_fn_name {
31 return true;
32 }
33 if let TokenTree::Ident(ident) = token_tree {
34 if ident == "async" {
35 is_async = true;
36 }
37 if ident == "fn" {
38 is_next_token_fn_name = true;
39 return false;
40 } else {
41 return false;
42 }
43 }
44 false
45 })
46 .ok_or(())?;
47
48 if let TokenTree::Ident(ident) = fn_name {
49 Ok((ident, is_async))
50 } else {
51 Err(())
52 }
53 }
54
55 fn get_fn_return_type(token_stream: TokenStreamV2) -> Option<Ident> {
56 let tokens_iter = token_stream.into_iter();
57
58 let mut return_type_token_index = None;
59 let mut is_next_token_return_type = false;
60 let return_type = tokens_iter.array_chunks::<2>().find(|[token1, token2]| {
61 if is_next_token_return_type {
62 return true;
63 }
64 if let TokenTree::Punct(punct1) = token1 && let TokenTree::Punct(punct2) = token2 {
65 let p1_char = punct1.as_char();
66 let p2_char = punct2.as_char();
67
68 if p1_char == '-' && p2_char == '>' {
69 is_next_token_return_type = true;
70 return_type_token_index = Some(0);
71 return false;
72 } else {
73 return false;
74 }
75 }
76 else if let TokenTree::Punct(punct) = token1 && let TokenTree::Ident(_) = token2 {
77 let p_char = punct.as_char();
78
79 if p_char == '>' {
80 return_type_token_index = Some(1);
81 return true;
82 }
83 }
84 false
85 })?;
86
87 if let TokenTree::Ident(return_type) = return_type[return_type_token_index.unwrap()].clone()
88 {
89 Some(return_type)
90 } else {
91 None
92 }
93 }
94
95 fn from_token_stream(token_stream: TokenStreamV2) -> Option<Self> {
96 let return_type = Self::get_fn_return_type(token_stream.clone());
97 let (fn_name, is_async) = Self::get_fn_name(token_stream).ok()?;
98
99 Some(Self {
100 is_async,
101 name: fn_name,
102 return_type,
103 })
104 }
105}
106
107fn generate_struct(
108 fn_name: &Ident,
109 mut tokens_iter: impl Iterator<Item = TokenTree>,
110) -> Option<(TokenStreamV2, Ident)> {
111 let struct_name = format_ident!("{}Args", fn_name.to_string().to_case(Case::Pascal));
112
113 let fn_args_tokens = {
114 let fn_args_group = tokens_iter.find(|token_tree| {
115 if let TokenTree::Group(group) = token_tree {
116 group.delimiter() == Delimiter::Parenthesis
117 } else {
118 false
119 }
120 })?;
121 if let TokenTree::Group(group) = fn_args_group {
122 group.stream()
123 } else {
124 return None;
125 }
126 };
127
128 Some((
129 quote! {
130 #[derive(Serialize, Deserialize, Debug)]
131 struct #struct_name {
132 #fn_args_tokens
133 }
134 },
135 struct_name,
136 ))
137}
138
139fn get_struct_field_names(tokens_iter: impl Iterator<Item = TokenTree>) -> Option<TokenStreamV2> {
140 let mut should_filter_next = false;
141
142 let variable_names_tokens = tokens_iter
143 .filter(|token_tree| {
144 if should_filter_next {
145 should_filter_next = false;
146 return false;
147 }
148 if let TokenTree::Punct(punct) = token_tree {
149 if punct.as_char() == ':' {
150 should_filter_next = true;
151 return false;
152 } else {
153 return true;
154 }
155 }
156 true
157 })
158 .collect::<TokenStreamV2>();
159
160 if variable_names_tokens.is_empty() {
161 None
162 } else {
163 Some(variable_names_tokens)
164 }
165}
166
167fn get_struct_fields(mut tokens_iter: impl Iterator<Item = TokenTree>) -> Option<TokenStreamV2> {
168 let struct_fields_group = tokens_iter.find(|token_tree| {
169 if let TokenTree::Group(group) = token_tree {
170 group.delimiter() == Delimiter::Brace
171 } else {
172 false
173 }
174 })?;
175
176 if let TokenTree::Group(group) = struct_fields_group {
177 let stream = group.stream();
178
179 if stream.is_empty() {
180 None
181 } else {
182 Some(stream)
183 }
184 } else {
185 None
186 }
187}
188
189fn generate_thunk(
190 fn_data: &FnData,
191 struct_name: &Ident,
192 tokens_iter: impl Iterator<Item = TokenTree>,
193 thunk_type: ThunkType,
194) -> Option<TokenStreamV2> {
195 let FnData {
196 is_async,
197 name,
198 return_type,
199 } = fn_data;
200
201 let thunk_name = match thunk_type {
202 ThunkType::Default => format_ident!("{}_thunk", name),
203 ThunkType::MessagePack => format_ident!("{}_messagepack_thunk", name),
204 };
205
206 let struct_fields_tokens = get_struct_fields(tokens_iter);
207
208 let variable_names_tokens = if struct_fields_tokens.is_some() {
209 get_struct_field_names(struct_fields_tokens?.into_iter())
210 } else {
211 None
212 };
213
214 let fn_prefix = if *is_async {
215 quote!(async fn)
216 } else {
217 quote!(fn)
218 };
219
220 let args_token_stream = if variable_names_tokens.is_none() {
221 quote!(())
222 } else {
223 match thunk_type {
224 ThunkType::Default => quote!((args: #struct_name)),
225 ThunkType::MessagePack => quote!((bytes: &[u8])),
226 }
227 };
228
229 let return_type_stream = if return_type.is_none() {
230 quote!()
231 } else {
232 quote!(-> #return_type)
233 };
234
235 let struct_unwrap_tokens = if variable_names_tokens.is_none() {
236 quote!()
237 } else {
238 quote!(let #struct_name { #variable_names_tokens } = args;)
239 };
240
241 let mut call_token_stream = if *is_async {
242 quote!(#name(#variable_names_tokens).await)
243 } else {
244 quote!(#name(#variable_names_tokens))
245 };
246 if return_type.is_none() {
247 call_token_stream.append_all(quote!(;));
248 }
249
250 match thunk_type {
251 ThunkType::Default => Some(quote! {
252 #fn_prefix #thunk_name #args_token_stream #return_type_stream {
253 #struct_unwrap_tokens
254 #call_token_stream
255 }
256 }),
257 ThunkType::MessagePack => {
258 if variable_names_tokens.is_some() {
259 Some(quote! {
260 #fn_prefix #thunk_name #args_token_stream #return_type_stream {
261 let args = rmp_serde::from_slice(bytes).unwrap();
262 #struct_unwrap_tokens
263 #call_token_stream
264 }
265 })
266 } else {
267 None
268 }
269 }
270 }
271}
272
273#[proc_macro_attribute]
274pub fn server_function(_attr: TokenStreamV1, item: TokenStreamV1) -> TokenStreamV1 {
275 let item = Into::<TokenStreamV2>::into(item);
276 let mut item_iter = item.clone().into_iter();
277
278 let fn_data =
279 FnData::from_token_stream(item.clone()).expect("Failed to extract function data!");
280 let (args_struct, args_struct_name) = generate_struct(&fn_data.name, &mut item_iter)
281 .expect("Failed to generate function arguments struct!");
282 let thunk = generate_thunk(
283 &fn_data,
284 &args_struct_name,
285 args_struct.clone().into_iter(),
286 ThunkType::Default,
287 )
288 .expect("Failed to generate function thunk!");
289
290 #[cfg(not(feature = "messagepack"))]
291 return quote! {
292 #args_struct
293 #thunk
294
295 #item
296 }
297 .into();
298
299 #[cfg(feature = "messagepack")]
300 let messagepack_thunk = generate_thunk(
301 &fn_data,
302 &args_struct_name,
303 args_struct.clone().into_iter(),
304 ThunkType::MessagePack,
305 );
306 #[cfg(feature = "messagepack")]
307 quote! {
308 #args_struct
309 #thunk
310 #messagepack_thunk
311
312 #item
313 }
314 .into()
315}