server_function/
lib.rs

1#![feature(iter_array_chunks)]
2#![feature(let_chains)]
3
4use convert_case::{Case, Casing};
5use proc_macro::TokenStream as TokenStreamV1;
6use proc_macro2::{Delimiter, Ident, TokenStream as TokenStreamV2, TokenTree};
7use quote::{format_ident, quote, TokenStreamExt};
8
/// Flavour of thunk emitted by `generate_thunk`.
// `MessagePack` is only constructed when the `messagepack` feature is enabled, so the
// dead-code lint is silenced for the other configuration only.
#[cfg_attr(not(feature = "messagepack"), allow(dead_code))]
#[derive(Debug, Clone, Copy)]
enum ThunkType {
    /// `<name>_thunk(args: <Name>Args)`: receives the arguments struct directly.
    Default,
    /// `<name>_messagepack_thunk(bytes: &[u8])`: decodes the arguments struct with
    /// `rmp_serde` before calling the function.
    MessagePack,
}
15
16#[derive(Debug)]
17struct FnData {
18    is_async: bool,
19    name: Ident,
20    return_type: Option<Ident>,
21}
22impl FnData {
23    fn get_fn_name(token_stream: TokenStreamV2) -> Result<(Ident, bool), ()> {
24        let mut tokens_iter = token_stream.into_iter();
25
26        let mut is_next_token_fn_name = false;
27        let mut is_async = false;
28        let fn_name = tokens_iter
29            .find(|token_tree| {
30                if is_next_token_fn_name {
31                    return true;
32                }
33                if let TokenTree::Ident(ident) = token_tree {
34                    if ident == "async" {
35                        is_async = true;
36                    }
37                    if ident == "fn" {
38                        is_next_token_fn_name = true;
39                        return false;
40                    } else {
41                        return false;
42                    }
43                }
44                false
45            })
46            .ok_or(())?;
47
48        if let TokenTree::Ident(ident) = fn_name {
49            Ok((ident, is_async))
50        } else {
51            Err(())
52        }
53    }
54
55    fn get_fn_return_type(token_stream: TokenStreamV2) -> Option<Ident> {
56        let tokens_iter = token_stream.into_iter();
57
58        let mut return_type_token_index = None;
59        let mut is_next_token_return_type = false;
60        let return_type = tokens_iter.array_chunks::<2>().find(|[token1, token2]| {
61            if is_next_token_return_type {
62                return true;
63            }
64            if let TokenTree::Punct(punct1) = token1 && let TokenTree::Punct(punct2) = token2 {
65                let p1_char = punct1.as_char();
66                let p2_char = punct2.as_char();
67
68                if p1_char == '-' && p2_char == '>' {
69                    is_next_token_return_type = true;
70                    return_type_token_index = Some(0);
71                    return false;
72                } else {
73                    return false;
74                }
75            }
76            else if let TokenTree::Punct(punct) = token1 && let TokenTree::Ident(_) = token2 {
77                let p_char = punct.as_char();
78
79                if p_char == '>' {
80                    return_type_token_index = Some(1);
81                    return true;
82                }
83            }
84            false
85        })?;
86
87        if let TokenTree::Ident(return_type) = return_type[return_type_token_index.unwrap()].clone()
88        {
89            Some(return_type)
90        } else {
91            None
92        }
93    }
94
95    fn from_token_stream(token_stream: TokenStreamV2) -> Option<Self> {
96        let return_type = Self::get_fn_return_type(token_stream.clone());
97        let (fn_name, is_async) = Self::get_fn_name(token_stream).ok()?;
98
99        Some(Self {
100            is_async,
101            name: fn_name,
102            return_type,
103        })
104    }
105}
106
107fn generate_struct(
108    fn_name: &Ident,
109    mut tokens_iter: impl Iterator<Item = TokenTree>,
110) -> Option<(TokenStreamV2, Ident)> {
111    let struct_name = format_ident!("{}Args", fn_name.to_string().to_case(Case::Pascal));
112
113    let fn_args_tokens = {
114        let fn_args_group = tokens_iter.find(|token_tree| {
115            if let TokenTree::Group(group) = token_tree {
116                group.delimiter() == Delimiter::Parenthesis
117            } else {
118                false
119            }
120        })?;
121        if let TokenTree::Group(group) = fn_args_group {
122            group.stream()
123        } else {
124            return None;
125        }
126    };
127
128    Some((
129        quote! {
130            #[derive(Serialize, Deserialize, Debug)]
131            struct #struct_name {
132                #fn_args_tokens
133            }
134        },
135        struct_name,
136    ))
137}
138
139fn get_struct_field_names(tokens_iter: impl Iterator<Item = TokenTree>) -> Option<TokenStreamV2> {
140    let mut should_filter_next = false;
141
142    let variable_names_tokens = tokens_iter
143        .filter(|token_tree| {
144            if should_filter_next {
145                should_filter_next = false;
146                return false;
147            }
148            if let TokenTree::Punct(punct) = token_tree {
149                if punct.as_char() == ':' {
150                    should_filter_next = true;
151                    return false;
152                } else {
153                    return true;
154                }
155            }
156            true
157        })
158        .collect::<TokenStreamV2>();
159
160    if variable_names_tokens.is_empty() {
161        None
162    } else {
163        Some(variable_names_tokens)
164    }
165}
166
167fn get_struct_fields(mut tokens_iter: impl Iterator<Item = TokenTree>) -> Option<TokenStreamV2> {
168    let struct_fields_group = tokens_iter.find(|token_tree| {
169        if let TokenTree::Group(group) = token_tree {
170            group.delimiter() == Delimiter::Brace
171        } else {
172            false
173        }
174    })?;
175
176    if let TokenTree::Group(group) = struct_fields_group {
177        let stream = group.stream();
178
179        if stream.is_empty() {
180            None
181        } else {
182            Some(stream)
183        }
184    } else {
185        None
186    }
187}
188
189fn generate_thunk(
190    fn_data: &FnData,
191    struct_name: &Ident,
192    tokens_iter: impl Iterator<Item = TokenTree>,
193    thunk_type: ThunkType,
194) -> Option<TokenStreamV2> {
195    let FnData {
196        is_async,
197        name,
198        return_type,
199    } = fn_data;
200
201    let thunk_name = match thunk_type {
202        ThunkType::Default => format_ident!("{}_thunk", name),
203        ThunkType::MessagePack => format_ident!("{}_messagepack_thunk", name),
204    };
205
206    let struct_fields_tokens = get_struct_fields(tokens_iter);
207
208    let variable_names_tokens = if struct_fields_tokens.is_some() {
209        get_struct_field_names(struct_fields_tokens?.into_iter())
210    } else {
211        None
212    };
213
214    let fn_prefix = if *is_async {
215        quote!(async fn)
216    } else {
217        quote!(fn)
218    };
219
220    let args_token_stream = if variable_names_tokens.is_none() {
221        quote!(())
222    } else {
223        match thunk_type {
224            ThunkType::Default => quote!((args: #struct_name)),
225            ThunkType::MessagePack => quote!((bytes: &[u8])),
226        }
227    };
228
229    let return_type_stream = if return_type.is_none() {
230        quote!()
231    } else {
232        quote!(-> #return_type)
233    };
234
235    let struct_unwrap_tokens = if variable_names_tokens.is_none() {
236        quote!()
237    } else {
238        quote!(let #struct_name { #variable_names_tokens } = args;)
239    };
240
241    let mut call_token_stream = if *is_async {
242        quote!(#name(#variable_names_tokens).await)
243    } else {
244        quote!(#name(#variable_names_tokens))
245    };
246    if return_type.is_none() {
247        call_token_stream.append_all(quote!(;));
248    }
249
250    match thunk_type {
251        ThunkType::Default => Some(quote! {
252            #fn_prefix #thunk_name #args_token_stream #return_type_stream {
253                #struct_unwrap_tokens
254                #call_token_stream
255            }
256        }),
257        ThunkType::MessagePack => {
258            if variable_names_tokens.is_some() {
259                Some(quote! {
260                    #fn_prefix #thunk_name #args_token_stream #return_type_stream {
261                        let args = rmp_serde::from_slice(bytes).unwrap();
262                        #struct_unwrap_tokens
263                        #call_token_stream
264                    }
265                })
266            } else {
267                None
268            }
269        }
270    }
271}
272
273#[proc_macro_attribute]
274pub fn server_function(_attr: TokenStreamV1, item: TokenStreamV1) -> TokenStreamV1 {
275    let item = Into::<TokenStreamV2>::into(item);
276    let mut item_iter = item.clone().into_iter();
277
278    let fn_data =
279        FnData::from_token_stream(item.clone()).expect("Failed to extract function data!");
280    let (args_struct, args_struct_name) = generate_struct(&fn_data.name, &mut item_iter)
281        .expect("Failed to generate function arguments struct!");
282    let thunk = generate_thunk(
283        &fn_data,
284        &args_struct_name,
285        args_struct.clone().into_iter(),
286        ThunkType::Default,
287    )
288    .expect("Failed to generate function thunk!");
289
290    #[cfg(not(feature = "messagepack"))]
291    return quote! {
292        #args_struct
293        #thunk
294
295        #item
296    }
297    .into();
298
299    #[cfg(feature = "messagepack")]
300    let messagepack_thunk = generate_thunk(
301        &fn_data,
302        &args_struct_name,
303        args_struct.clone().into_iter(),
304        ThunkType::MessagePack,
305    );
306    #[cfg(feature = "messagepack")]
307    quote! {
308        #args_struct
309        #thunk
310        #messagepack_thunk
311
312        #item
313    }
314    .into()
315}