Skip to main content

wslink_rs_marco/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::{format_ident, quote};
4use syn::{parse_macro_input, FnArg, GenericArgument, Index, ItemFn, LitStr, Path, ReturnType, Type};
5
6#[proc_macro_attribute]
7pub fn wslink_rpc(attr: TokenStream, item: TokenStream) -> TokenStream {
8    let input_fn = parse_macro_input!(item as ItemFn);
9    let fn_name = &input_fn.sig.ident;
10
11    let name = if attr.is_empty() {
12        fn_name.to_string()
13    } else {
14        parse_macro_input!(attr as LitStr).value()
15    };
16
17    let mut is_mut_self = false;
18
19    // collect all parameter types
20    let param_types: Vec<_> = input_fn
21        .sig
22        .inputs
23        .iter()
24        .filter_map(|arg| match arg {
25            FnArg::Receiver(v) => {
26                is_mut_self = v.mutability.is_some();
27                None
28            }
29            FnArg::Typed(pat_type) => Some(&pat_type.ty),
30        })
31        .collect();
32
33    let rpc_name = format_ident!("{}_rpc", fn_name);
34
35    // 1. Extract function return type
36    let return_type = match &input_fn.sig.output {
37        ReturnType::Default => {
38            return syn::Error::new_spanned(input_fn, "函数必须有返回类型")
39                .to_compile_error()
40                .into();
41        }
42        ReturnType::Type(_, ty) => ty,
43    };
44
45    // 2. Parse anyhow::Result<T>
46    let inner_type = parse_anyhow_result(return_type);
47
48    let mut index = Vec::new();
49    for i in 0..param_types.len() {
50        index.push(Index::from(i));
51    }
52
53    let do_result_unwrap = if inner_type == return_type.as_ref() {
54        quote! {}
55    } else {
56        quote! {?}
57    };
58
59    let define_req_type = quote! {
60        #[derive(Deserialize)]
61        struct Req { args: (#(#param_types,)*) }
62    };
63
64    let do_call_rpc_function = match param_types.len() {
65        0 => quote! { #rpc_name()#do_result_unwrap },
66        _ => quote! { #rpc_name(#(d.args.#index,)*)#do_result_unwrap },
67    };
68
69    let do_call_req_decode = match param_types.len() {
70        0 => quote! {},
71        _ => quote! { let d: Req = rmp_serde::decode::from_slice(&d)?; },
72    };
73
74    let downcast = if is_mut_self {
75        quote! { downcast_mut::<Self>() }
76    } else {
77        quote! { downcast_ref::<Self>() }
78    };
79
80    let mut new_input_fn = input_fn.clone();
81    new_input_fn.sig.ident = Ident::new(&format!("{}_rpc", fn_name), fn_name.span());
82
83    let expanded = quote! {
84        // redefine function body
85        fn #fn_name() -> WsLinkRpc {
86            use serde::{Serialize, Deserialize};
87            use rmp_serde;
88            use rmp_serde::Serializer;
89
90            let rpc_func = |object: &mut dyn ServerProtocol, _client_id: usize, _rpc_id: &str, d: &[u8]| -> anyhow::Result<Box<Local2AsyncFn>> {
91                #define_req_type
92
93                #[derive(Serialize)]
94                struct Rsp {
95                    wslink: String,
96                    id: String,
97                    result: #inner_type,
98                }
99
100                #do_call_req_decode
101
102                let r = object
103                    .#downcast
104                    .expect("wslink-rs RPC marco downcast failed")
105                    .#do_call_rpc_function;
106
107                let ff = Box::new(move |rpc_id: &str| -> anyhow::Result<Vec<u8>>{
108                    let res = WsRsp {
109                        wslink: "1.0".into(),
110                        id: rpc_id.into(),
111                        result: r,
112                    };
113
114                    let data = wslink_rs::rmp::to_vec_named(&res)?;
115
116                    Ok(data)
117                });
118
119                Ok(ff)
120            };
121
122            WsLinkRpc::new(#name, Box::new(rpc_func))
123        }
124
125        // original function with new name
126        #new_input_fn
127    };
128
129    TokenStream::from(expanded)
130}
131
132// Parse anyhow::Result<T> and return T
133fn parse_anyhow_result(return_type: &Box<Type>) -> &Type {
134    // Check whether this is a path type (e.g. anyhow::Result<T>)
135    if let Type::Path(type_path) = &**return_type {
136        let path = &type_path.path;
137
138        // Check whether the path matches anyhow::Result or Result
139        if is_anyhow_result_path(path) {
140            // Extract the first generic argument (T)
141            if let Some(segment) = path.segments.last() {
142                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
143                    if let Some(GenericArgument::Type(inner_type)) = args.args.first() {
144                        return inner_type;
145                    }
146                }
147            }
148        }
149    }
150
151    return_type
152}
153
154// Check whether the path matches anyhow::Result or Result (including aliases)
155fn is_anyhow_result_path(path: &Path) -> bool {
156    let segments: Vec<_> = path.segments.iter().map(|s| s.ident.to_string()).collect();
157
158    // Match the following forms:
159    // - anyhow::Result
160    // - Result (when anyhow::Result is imported/aliased)
161    // - some::path::Result (as long as the last segment matches)
162    let last_segment = segments.last().map(|s| s.as_str());
163
164    // Check whether the path ends with "Result"
165    last_segment == Some("Result")
166}