tcp_struct_macros/
lib.rs

1extern crate proc_macro;
2use once_cell::sync::Lazy;
3use proc_macro::TokenStream;
4use quote::quote;
5use std::{collections::HashMap, sync::Mutex};
6use syn::{parse_macro_input, DeriveInput, FnArg, ItemImpl, Pat, ReturnType};
7
8static FUNC_REGISTRY: Lazy<Mutex<HashMap<String, HashMap<String, Fn>>>> =
9    Lazy::new(|| Mutex::new(Default::default()));
10
11fn str_to_ident(s: impl ToString) -> syn::Ident {
12    syn::Ident::new(&s.to_string(), proc_macro2::Span::call_site())
13}
14
15#[proc_macro_derive(TCPShare)]
16pub fn derive_answer_fn(input: TokenStream) -> proc_macro::TokenStream {
17    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
18    let name: &syn::Ident = &input.ident;
19    let name_str = name.to_string();
20    let reader_service_name_str = str_to_ident(format!("{}Reader", name_str));
21    let writer_service_name_str = str_to_ident(format!("{}Writer", name_str));
22
23    let elements = FUNC_REGISTRY.lock().unwrap();
24    let elements = elements.get(&name_str);
25    let cases = elements
26        .as_ref()
27        .map(|v| {
28            v.values()
29                .map(|v| v.body.parse::<proc_macro2::TokenStream>().unwrap())
30                .collect::<Vec<_>>()
31        })
32        .unwrap_or_default();
33
34    let elements = elements
35        .map(|v| {
36            v.values()
37                .map(|v| {
38                    let name = str_to_ident(&v.name);
39                    let r_type = &v.ret_type;
40                    let args = v
41                        .args
42                        .iter()
43                        .filter(|v| v.contains(":"))
44                        .map(|arg_str| {
45                            let token_stream: proc_macro2::TokenStream = arg_str.parse().unwrap();
46                            token_stream
47                        })
48                        .collect::<Vec<_>>();
49                    let data: proc_macro2::TokenStream = v.to_data.parse().unwrap();
50
51                    #[cfg(feature = "async-tcp")]
52                    match r_type {
53                        Some(r_type) => {
54                            let r_type: proc_macro2::TokenStream = r_type.parse().unwrap();
55                            quote! {pub async fn #name(&self, #(#args),*) -> Result<#r_type, tcp_struct::Error>{
56                                #data
57                            }}
58                        }
59                        None => quote! {pub async fn #name(&self, #(#args),*) -> Result<(), tcp_struct::Error>{
60                            #data
61                        }},
62                    }
63                    #[cfg(not(feature = "async-tcp"))]
64                    match r_type {
65                        Some(r_type) => {
66                            let r_type: proc_macro2::TokenStream = r_type.parse().unwrap();
67                            quote! {pub fn #name(&self, #(#args),*) -> Result<#r_type, tcp_struct::Error>{
68                                #data
69                            }}
70                        }
71                        None => quote! {pub fn #name(&self, #(#args),*) -> Result<(), tcp_struct::Error>{
72                            #data
73                        }},
74                    }
75
76                })
77                .collect::<Vec<_>>()
78        })
79        .unwrap_or_default();
80    let quote = quote! {
81        pub struct #reader_service_name_str {
82            port: u16,
83            head: String,
84        }
85        pub struct #writer_service_name_str {
86            data: std::sync::Arc<tokio::sync::Mutex<#name>>,
87        }
88
89        impl tcp_struct::Receiver<#name> for #writer_service_name_str {
90            fn request(func: String, data: Vec<u8>, app_data: std::sync::Arc<tokio::sync::Mutex<#name>>) -> std::pin::Pin<Box<dyn std::future::Future<Output = tcp_struct::Result<Vec<u8>>> + Send>> {
91                Box::pin(async move {
92                match func.as_str() {
93                    #(#cases)*
94                    _ => Err(tcp_struct::Error::FunctionNotFound)
95                }
96                })
97            }
98            fn get_app_data(&self) -> std::sync::Arc<tokio::sync::Mutex<#name>> {
99                self.data.clone()
100            }
101        }
102
103        impl #reader_service_name_str {
104            #(#elements)*
105        }
106
107        impl #name {
108            pub fn read(port: u16, head: &str) -> #reader_service_name_str {
109                #reader_service_name_str {
110                    port,
111                    head: head.to_string()
112                }
113            }
114        }
115
116        impl tcp_struct::Starter for #name {
117            async fn start(self, port: u16, header: &str) -> std::io::Result<()> {
118                use tcp_struct::Receiver as _;
119                #writer_service_name_str {
120                    data: std::sync::Arc::new(tokio::sync::Mutex::new(self))
121                }.start(port, header).await
122            }
123
124            async fn start_from_listener(self, listener: tcp_struct::TcpListener, header: &str) -> std::io::Result<()> {
125                use tcp_struct::Receiver as _;
126                #writer_service_name_str {
127                    data: std::sync::Arc::new(tokio::sync::Mutex::new(self))
128                }.start_from_listener(listener, header).await
129            }
130        }
131    };
132    quote.into()
133}
134
135#[proc_macro_attribute]
136pub fn register_impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
137    let input = parse_macro_input!(item as ItemImpl);
138
139    let struct_name = if let syn::Type::Path(type_path) = &*input.self_ty {
140        if let Some(segment) = type_path.path.segments.last() {
141            segment.ident.to_string()
142        } else {
143            "Unknown".to_string()
144        }
145    } else {
146        "Unknown".to_string()
147    };
148
149    let func_names: Vec<Fn> = input
150        .items
151        .iter()
152        .filter_map(|item| {
153            if let syn::ImplItem::Method(method) = item {
154                let asyncronous = method.sig.asyncness.is_some();
155                let name = &method.sig.ident;
156                let name_str = name.to_string();
157                let args: Vec<String> = method
158                    .sig
159                    .inputs
160                    .iter()
161                    .map(|arg| quote!(#arg).to_string())
162                    .collect();
163
164                let arg_names = method.sig.inputs.iter().flat_map(|v|{
165                    match v {
166                        syn::FnArg::Receiver(_) => None,
167                        syn::FnArg::Typed(pat_type) => {
168                            if let Pat::Ident(pat_ident) = &*pat_type.pat {
169                                Some(&pat_ident.ident)
170                            } else {
171                                None
172                            }
173                        },
174                    }
175                }).collect::<Vec<_>>();
176                let to_data = match arg_names.is_empty() {
177                    false => quote! {tcp_struct::encode((#(#arg_names),*))?},
178                    true => quote! {vec![]},
179                };
180                #[cfg(feature = "async-tcp")]
181                let to_data = quote! {Ok(tcp_struct::decode(&tcp_struct::send_data(self.port,&self.head,#name_str,#to_data).await?)?)}.to_string();
182                #[cfg(not(feature = "async-tcp"))]
183                let to_data = quote! {Ok(tcp_struct::decode(&tcp_struct::send_data(self.port,&self.head,#name_str,#to_data)?)?)}.to_string();
184                let arg_types: Vec<_> = method.sig.inputs
185                    .iter()
186                    .filter_map(|arg| {
187                        if let FnArg::Typed(pat_type) = arg {
188                            let ty = &pat_type.ty;
189                            Some(quote! {#ty})
190                        } else {
191                            None
192                        }
193                    })
194                    .collect();
195
196                let seperator = match arg_names.len() == args.len() {
197                    true => {
198                        let struct_name = str_to_ident(&struct_name);
199                        quote! {#struct_name::#name}},
200                    false => quote! {app_data.lock().await.#name},
201                };
202
203                let load = match arg_names.is_empty() {
204                    true => quote! {},
205                    false => quote! {let (#(#arg_names),*) = tcp_struct::decode::<(#(#arg_types),*)>(&data)?;}
206                };
207
208                let asyncronous = match asyncronous {
209                    true => quote! {.await},
210                    false => quote! {}
211                };
212
213                let body = quote! {
214                    #name_str => {
215                        #load
216                        tcp_struct::encode(#seperator(#(#arg_names),*)#asyncronous)
217                    },
218                }.to_string();
219
220                Some(Fn {
221                    body,
222                    to_data,
223                    args,
224                    ret_type: match &method.sig.output {
225                        ReturnType::Default => None,
226                        ReturnType::Type(_, ty) => Some(format!("{}", quote!(#ty))),
227                    },
228                    name: name_str,
229                })
230            } else {
231                None
232            }
233        })
234        .collect();
235
236    {
237        let mut registry = FUNC_REGISTRY.lock().unwrap();
238        let map = registry.entry(struct_name).or_default();
239        for func_name in func_names {
240            map.insert(func_name.name.clone(), func_name);
241        }
242    }
243
244    TokenStream::from(quote!(#input))
245}
246
/// Registry entry describing one method recorded by `#[register_impl]`.
///
/// Apart from `name`, the fields hold token text rendered with `quote!`,
/// which `#[derive(TCPShare)]` parses back into tokens.
struct Fn {
    /// Method name; also the string literal matched on during request dispatch.
    name: String,
    /// Each parameter as written, receiver included (e.g. `& self`, `x : u32`).
    args: Vec<String>,
    /// Declared return type, or `None` when the method has no `-> Type`.
    ret_type: Option<String>,
    /// Server-side `match` arm: decodes the arguments, calls the method, encodes the result.
    body: String,
    /// Client-side method body: sends the encoded arguments and decodes the reply.
    to_data: String,
}