1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::{format_ident, quote};
4use syn::{parse_macro_input, FnArg, GenericArgument, Index, ItemFn, LitStr, Path, ReturnType, Type};
5
6#[proc_macro_attribute]
7pub fn wslink_rpc(attr: TokenStream, item: TokenStream) -> TokenStream {
8 let input_fn = parse_macro_input!(item as ItemFn);
9 let fn_name = &input_fn.sig.ident;
10
11 let name = if attr.is_empty() {
12 fn_name.to_string()
13 } else {
14 parse_macro_input!(attr as LitStr).value()
15 };
16
17 let mut is_mut_self = false;
18
19 let param_types: Vec<_> = input_fn
21 .sig
22 .inputs
23 .iter()
24 .filter_map(|arg| match arg {
25 FnArg::Receiver(v) => {
26 is_mut_self = v.mutability.is_some();
27 None
28 }
29 FnArg::Typed(pat_type) => Some(&pat_type.ty),
30 })
31 .collect();
32
33 let rpc_name = format_ident!("{}_rpc", fn_name);
34
35 let return_type = match &input_fn.sig.output {
37 ReturnType::Default => {
38 return syn::Error::new_spanned(input_fn, "函数必须有返回类型")
39 .to_compile_error()
40 .into();
41 }
42 ReturnType::Type(_, ty) => ty,
43 };
44
45 let inner_type = parse_anyhow_result(return_type);
47
48 let mut index = Vec::new();
49 for i in 0..param_types.len() {
50 index.push(Index::from(i));
51 }
52
53 let do_result_unwrap = if inner_type == return_type.as_ref() {
54 quote! {}
55 } else {
56 quote! {?}
57 };
58
59 let define_req_type = quote! {
60 #[derive(Deserialize)]
61 struct Req { args: (#(#param_types,)*) }
62 };
63
64 let do_call_rpc_function = match param_types.len() {
65 0 => quote! { #rpc_name()#do_result_unwrap },
66 _ => quote! { #rpc_name(#(d.args.#index,)*)#do_result_unwrap },
67 };
68
69 let do_call_req_decode = match param_types.len() {
70 0 => quote! {},
71 _ => quote! { let d: Req = rmp_serde::decode::from_slice(&d)?; },
72 };
73
74 let downcast = if is_mut_self {
75 quote! { downcast_mut::<Self>() }
76 } else {
77 quote! { downcast_ref::<Self>() }
78 };
79
80 let mut new_input_fn = input_fn.clone();
81 new_input_fn.sig.ident = Ident::new(&format!("{}_rpc", fn_name), fn_name.span());
82
83 let expanded = quote! {
84 fn #fn_name() -> WsLinkRpc {
86 use serde::{Serialize, Deserialize};
87 use rmp_serde;
88 use rmp_serde::Serializer;
89
90 let rpc_func = |object: &mut dyn ServerProtocol, _client_id: usize, _rpc_id: &str, d: &[u8]| -> anyhow::Result<Box<Local2AsyncFn>> {
91 #define_req_type
92
93 #[derive(Serialize)]
94 struct Rsp {
95 wslink: String,
96 id: String,
97 result: #inner_type,
98 }
99
100 #do_call_req_decode
101
102 let r = object
103 .#downcast
104 .expect("wslink-rs RPC marco downcast failed")
105 .#do_call_rpc_function;
106
107 let ff = Box::new(move |rpc_id: &str| -> anyhow::Result<Vec<u8>>{
108 let res = WsRsp {
109 wslink: "1.0".into(),
110 id: rpc_id.into(),
111 result: r,
112 };
113
114 let data = wslink_rs::rmp::to_vec_named(&res)?;
115
116 Ok(data)
117 });
118
119 Ok(ff)
120 };
121
122 WsLinkRpc::new(#name, Box::new(rpc_func))
123 }
124
125 #new_input_fn
127 };
128
129 TokenStream::from(expanded)
130}
131
132fn parse_anyhow_result(return_type: &Box<Type>) -> &Type {
134 if let Type::Path(type_path) = &**return_type {
136 let path = &type_path.path;
137
138 if is_anyhow_result_path(path) {
140 if let Some(segment) = path.segments.last() {
142 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
143 if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
144 return inner_type;
145 }
146 }
147 }
148 }
149 }
150
151 return_type
152}
153
154fn is_anyhow_result_path(path: &Path) -> bool {
156 let segments: Vec<_> = path.segments.iter().map(|s| s.ident.to_string()).collect();
157
158 let last_segment = segments.last().map(|s| s.as_str());
163
164 last_segment == Some("Result")
166}