1#![forbid(unsafe_code)]
2
3use std::borrow::Cow;
4
5use proc_macro::TokenStream;
6use proc_macro_error::*;
7use quote::{format_ident, quote, ToTokens};
8use syn::{spanned::Spanned, *};
9
10#[proc_macro_error]
11#[proc_macro_attribute]
12pub fn rpc_trait(_args: TokenStream, trait_body: TokenStream) -> TokenStream {
13 rpc_define(trait_body)
14}
15
16#[proc_macro_error]
18#[proc_macro]
19pub fn rpc_define(trait_body: TokenStream) -> TokenStream {
20 let root: Path = parse_quote! { ::tiny_rpc::rpc::re_export }; let trait_body = parse_macro_input!(trait_body as ItemTrait);
23 let ident = &trait_body.ident;
24 let func_list = gen_func_list(&trait_body);
25 let (req_rsp_body, req_ident, rsp_ident) =
26 gen_req_rsp(&root, &trait_body.vis, ident, &func_list);
27 let server_body = gen_server(&root, ident, &req_ident, &rsp_ident, &func_list);
28 let client_body = gen_client(&root, ident, &req_ident, &rsp_ident, &func_list);
29
30 let ret = quote! {
31 #trait_body
32 #req_rsp_body
33 #server_body
34 #client_body
35 };
36 if option_env!("RUST_TRACE_MACROS").is_some() {
37 println!("{}", ret);
38 }
39 ret.into()
40}
41
42fn is_ref_receiver(arg: Option<&FnArg>) -> bool {
44 let arg = match arg {
45 Some(arg) => arg,
46 None => return false,
47 };
48
49 match arg {
50 FnArg::Receiver(receiver) => receiver.reference.is_some() && receiver.mutability.is_none(), FnArg::Typed(PatType { pat, ty, .. }) => {
52 matches!(pat.as_ref(), Pat::Ident(ident) if ident.ident == "self") && matches!(
54 ty.as_ref(),
55 Type::Reference(TypeReference{ mutability, elem, ..})
56 if mutability.is_none() && matches!(elem.as_ref(), Type::Path(path) if path.qself.is_none() && path.path.is_ident("Self")) )
58 }
59 }
60}
61
62fn gen_func_list(trait_body: &ItemTrait) -> Vec<Cow<'_, TraitItemMethod>> {
70 let ref_receiver: FnArg = parse_quote!(&self); trait_body
73 .items
74 .iter()
75 .filter_map(|item| match item {
76 TraitItem::Method(method) => {
77 let mut method = Cow::Borrowed(method);
80 if method.default.is_some() {
81 emit_error!(
82 method.default,
83 "trait method can't have default implementation"
84 );
85
86 let mut dummy = method.into_owned();
88 dummy.semi_token = Some(Token));
89 dummy.default = None;
90 method = Cow::Owned(dummy);
91 }
92
93 if !is_ref_receiver(method.sig.inputs.first()) {
94 emit_error!(method, "trait method must have `&self` receiver");
95
96 let mut dummy = method.into_owned();
98 match dummy.sig.inputs.first() {
99 Some(FnArg::Receiver(_)) => {
100 *(dummy
102 .sig
103 .inputs
104 .first_mut()
105 .expect("infallible: non-mutable use before")) =
106 ref_receiver.clone();
107 }
108 Some(FnArg::Typed(PatType { pat, .. })) => match &**pat {
109 Pat::Ident(PatIdent { ident, .. }) if ident == "self" => {
110 *(dummy
112 .sig
113 .inputs
114 .first_mut()
115 .expect("infallible: non-mutable use before")) =
116 ref_receiver.clone();
117 }
118 _ => {
119 dummy.sig.inputs.insert(0, ref_receiver.clone());
121 }
122 },
123 None => {
124 dummy.sig.inputs.insert(0, ref_receiver.clone());
126 }
127 }
128 method = Cow::Owned(dummy);
129 }
130
131 for i in 0..(method.sig.inputs.len()) {
132 if let FnArg::Typed(PatType { ref pat, .. }) = method.sig.inputs[i] {
133 match pat.as_ref() {
134 Pat::Ident(_) => {}
135 other => {
136 emit_error!(other, "trait method cannot use pattern as argument");
137
138 let dummy_ident = format_ident!("__dummy_{:x}", {
139 use std::hash::{Hash, Hasher};
140
141 let mut h =
142 std::collections::hash_map::DefaultHasher::default();
143 other.hash(&mut h);
144 h.finish()
145 });
146
147 let new_pat = Box::new(Pat::Ident(PatIdent {
148 ident: dummy_ident,
149 attrs: Default::default(),
150 by_ref: None,
151 mutability: None,
152 subpat: None,
153 }));
154
155 let mut dummy = method.into_owned();
156 match dummy.sig.inputs[i] {
157 FnArg::Typed(PatType { ref mut pat, .. }) => *pat = new_pat,
158 _ => unreachable!(),
159 }
160 method = Cow::Owned(dummy);
161 }
162 }
163 }
164 }
165
166 for lifetime in method.sig.generics.lifetimes() {
167 if lifetime.lifetime.ident != "req" {
168 emit_error!(
169 lifetime.lifetime.span(),
170 "trait method may only have one lifetime parameter called `'req`"
171 );
172 }
173 }
174
175 Some(method)
176 }
177 item => {
178 emit_error!(
179 item,
180 "#[rpc_define] trait cannot have any item other than function"
181 );
182 None
183 }
184 })
185 .collect::<Vec<_>>()
186}
187
188fn gen_req_rsp<'a>(
189 root: &Path,
190 vis: &Visibility,
191 ident: &Ident,
192 func_list: &[Cow<'a, TraitItemMethod>],
193) -> (proc_macro2::TokenStream, Ident, Ident) {
194 let unit_type = parse_quote!(()); let serde_borrow_attr: Attribute = parse_quote!(#[serde(borrow)]); let req_ident = format_ident!("{}Request", ident);
198 let rsp_ident = format_ident!("{}Response", ident);
199 let serde_path = format!("{}::serde", root.to_token_stream());
200 let serde_path = LitStr::new(serde_path.as_str(), root.span());
201
202 let func_ident = func_list
203 .iter()
204 .map(|method| &method.sig.ident)
205 .collect::<Vec<_>>();
206 let input_type = func_list.iter().map(|method| {
207 method
208 .sig
209 .inputs
210 .iter()
211 .skip(1) .map(|input| match input {
213 FnArg::Typed(PatType { ty, .. }) => ty,
214 FnArg::Receiver(_) => unreachable!(),
215 })
216 .collect::<Vec<_>>()
217 });
218 let input_borrow = func_list.iter().map(|method| {
219 if method.sig.generics.lifetimes().next().is_some() {
220 Some(&serde_borrow_attr)
221 } else {
222 None
223 }
224 });
225 let output_type = func_list.iter().map(|method| match method.sig.output {
226 ReturnType::Default => &unit_type,
227 ReturnType::Type(_, ref ty) => ty.as_ref(),
228 });
229
230 let req_rsp = quote! {
231 #[derive(#root::Serialize, #root::Deserialize)]
232 #[serde(crate = #serde_path)]
233 #[serde(deny_unknown_fields)]
234 #[allow(non_camel_case_types)]
235 #vis enum #req_ident<'req> {
236 #( #func_ident ( #input_borrow ( #(#input_type,)* ) ), )*
237 ___tiny_rpc_marker((#root::Never, #root::PhantomData<&'req ()>))
238 }
239
240 #[derive(#root::Serialize, #root::Deserialize)]
241 #[serde(crate = #serde_path)]
242 #[serde(deny_unknown_fields)]
243 #[allow(non_camel_case_types)]
244 #vis enum #rsp_ident {
245 #( #func_ident ( #output_type ), )*
246 }
247 };
248
249 (req_rsp, req_ident, rsp_ident)
250}
251
252fn gen_server<'a>(
253 root: &Path,
254 ident: &Ident,
255 req_ident: &Ident,
256 rsp_ident: &Ident,
257 func_list: &[Cow<'a, TraitItemMethod>],
258) -> proc_macro2::TokenStream {
259 let null_stream = quote! {}; let keyword_await = quote! { .await }; let server_ident = format_ident!("{}Server", ident);
263 let func_ident = func_list
264 .iter()
265 .map(|method| &method.sig.ident)
266 .collect::<Vec<_>>();
267 let input_ident = func_list
268 .iter()
269 .map(|method| {
270 method
271 .sig
272 .inputs
273 .iter()
274 .filter_map(|input| match input {
275 FnArg::Receiver(_) => None, FnArg::Typed(PatType { pat, .. }) => match &**pat {
277 Pat::Ident(ident) => Some(&ident.ident),
278 _ => unreachable!(),
279 },
280 })
281 .collect::<Vec<_>>()
282 })
283 .collect::<Vec<_>>();
284 let await_if_async = func_list.iter().map(|method| {
285 method
286 .sig
287 .asyncness
288 .map_or(&null_stream, |_| &keyword_await)
289 });
290
291 quote! {
292 pub struct #server_ident<T: #ident + #root::Send + #root::Sync + 'static>(#root::Arc<T>);
293
294 impl<T: #ident + #root::Send + #root::Sync + 'static> #server_ident<T> {
295 pub fn serve(server_impl: T, transport: #root::Transport) -> #root::BoxStream<'static, #root::BoxFuture<'static, ()>> {
296 Self::__internal_serve(Self(#root::Arc::new(server_impl)), transport)
297 }
298
299 pub fn serve_arc(server_impl: #root::Arc<T>, transport: #root::Transport) -> #root::BoxStream<'static, #root::BoxFuture<'static, ()>> {
300 Self::__internal_serve(Self(server_impl), transport)
301 }
302
303 fn __internal_serve(self, transport: #root::Transport) -> #root::BoxStream<'static, #root::BoxFuture<'static, ()>> {
304 #root::Server::serve(self, transport)
305 }
306 }
307
308 impl<T: #ident + #root::Send + #root::Sync + 'static> #root::Clone for #server_ident<T> {
309 fn clone(&self) -> Self {
310 Self(#root::Clone::clone(&self.0))
311 }
312 }
313
314 impl<T: #ident + #root::Send + #root::Sync + 'static> #root::Server for #server_ident<T> {
315 fn make_response(
316 self,
317 req: #root::RpcFrame,
318 ) -> #root::BoxFuture<'static, #root::Result<#root::RpcFrame>> {
319 #root::FutureExt::boxed(
320 async move {
321 let id = req.id()?;
322 let req = req.data()?;
323 let rsp = match req {
324 #(
325 #req_ident::#func_ident( ( #(#input_ident,)* ) ) => {
326 #rsp_ident::#func_ident(self.0.#func_ident(#(#input_ident),*) #await_if_async)
327 }
328 )*
329 #req_ident::___tiny_rpc_marker(_) => #root::unreachable!(),
330 };
331 let rsp = #root::RpcFrame::new(id, rsp)?;
332 Ok(rsp)
333 }
334 )
335 }
336 }
337 }
338}
339
340fn gen_client<'a>(
341 root: &Path,
342 ident: &Ident,
343 req_ident: &Ident,
344 rsp_ident: &Ident,
345 func_list: &[Cow<'a, TraitItemMethod>],
346) -> proc_macro2::TokenStream {
347 let unit_type: Type = parse_quote!(()); let client_ident = format_ident!("{}Client", ident);
350 let signature = func_list
351 .iter()
352 .cloned()
353 .map(|method| {
354 let method = method.into_owned();
355 let span = method.span();
356 let mut sig = method.sig;
357
358 if sig.asyncness.is_none() {
359 sig.asyncness = Some(Token);
360 }
361
362 let ty = match sig.output {
363 ReturnType::Type(_, ty) => *ty,
364 ReturnType::Default => unit_type.clone(),
365 };
366 sig.output = parse_quote! { -> #root::Result<#ty> };
367 sig
368 })
369 .collect::<Vec<_>>();
370 let arg_ident = signature.iter().map(|sig| {
371 sig.inputs
372 .iter()
373 .filter_map(|arg| match arg {
374 FnArg::Receiver(_) => None,
375 FnArg::Typed(PatType { pat, .. }) => match &**pat {
376 Pat::Ident(ident) => Some(&ident.ident),
377 _ => unreachable!(),
378 },
379 })
380 .collect::<Vec<_>>()
381 });
382 let func_ident = signature.iter().map(|sig| &sig.ident);
383
384 quote! {
385 #[derive(Clone)]
386 pub struct #client_ident(#root::IdGenerator, #root::ClientDriverHandle);
387
388 impl #root::Client for #client_ident {
389 fn from_handle(handle: #root::ClientDriverHandle) -> Self {
390 Self(#root::IdGenerator::new(), handle)
391 }
392
393 fn handle(&self) -> &#root::ClientDriverHandle {
394 &self.1
395 }
396 }
397
398 impl #client_ident {
399 pub fn new(transport: #root::Transport) -> (Self, #root::BoxFuture<'static, ()>) {
400 #root::Client::new(transport)
401 }
402
403 #(
404 pub #signature {
405 let args = ( #(#arg_ident,)* );
406 let id = self.0.next();
407 let span = info_span!(#root::stringify!(#func_ident), %id);
408
409 #root::Instrument::instrument(
410 async move {
411 let req = #req_ident::#func_ident(args);
412 let req = #root::RpcFrame::new(id, req)?;
413 let rsp = <Self as #root::Client>::make_request(self, req).await?;
414 let rsp = rsp.data()?;
415 match rsp {
416 #rsp_ident::#func_ident(ret) => Ok(ret),
417 _ => Err(#root::Into::into(#root::ProtocolError::ResponseMismatch(id))),
418 }
419 },
420 span,
421 )
422 .await
423 }
424 )*
425 }
426 }
427}
428
#[test]
fn test_is_ref_receiver() {
    // Each receiver form is paired with whether it counts as a shared `&self`
    // receiver. Keeping input and expectation together means a case cannot be
    // silently dropped by mismatched list lengths (as `zip` would do).
    let cases: [(FnArg, bool); 10] = [
        (parse_quote!(self), false),
        (parse_quote!(&self), true),
        (parse_quote!(&'a self), true),
        (parse_quote!(&mut self), false),
        (parse_quote!(&'a mut self), false),
        (parse_quote!(self: Self), false),
        (parse_quote!(self: &Self), true),
        (parse_quote!(self: &'a Self), true),
        (parse_quote!(self: &mut Self), false),
        (parse_quote!(self: &'a mut Self), false),
    ];

    assert!(!is_ref_receiver(None));
    for (arg, expected) in &cases {
        assert_eq!(
            is_ref_receiver(Some(arg)),
            *expected,
            "unexpected result for `{}`",
            arg.to_token_stream()
        );
    }
}