xy_rpc_macro/lib.rs
1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use proc_macro2::{Ident, Span};
4use quote::{format_ident, quote};
5use std::iter::once;
6use syn::punctuated::Punctuated;
7use syn::{
8 parse_macro_input, parse_quote, Field, FieldMutability, Fields, FieldsNamed, FieldsUnnamed,
9 FnArg, GenericArgument, GenericParam, Generics, ItemEnum, ItemTrait, Lifetime, LifetimeParam,
10 PathArguments, ReturnType, Token, TraitItem, Type, TypeParamBound, TypeReference, Variant,
11 Visibility,
12};
13
14#[proc_macro_attribute]
15pub fn rpc_service(_attr: TokenStream, item: TokenStream) -> TokenStream {
16 let mut ast = parse_macro_input!(item as ItemTrait);
17 let vis = ast.vis.clone();
18 {
19 for item in ast.items.iter_mut() {
20 if let TraitItem::Fn(f) = item {
21 if let Some(_r) = f.sig.asyncness.take() {
22 f.sig.output = match &f.sig.output {
23 ReturnType::Default => parse_quote! {
24 -> impl core::future::Future<Output = ()> + xy_rpc::maybe_send::MaybeSend
25 },
26 ReturnType::Type(_, ty) => parse_quote! {
27 -> impl core::future::Future<Output = #ty> + xy_rpc::maybe_send::MaybeSend
28 },
29 };
30 }
31 }
32 }
33 }
34 let trait_ident = &ast.ident;
35 let msg_enum_ident = format_ident!("{}Msg", ast.ident);
36 let msg_ref_enum_ident = format_ident!("{}RefMsg", ast.ident);
37 let msg_reply_enum_ident = format_ident!("{}ReplyMsg", ast.ident);
38 // let msg_reply_ref_enum_ident = format_ident!("{}ReplyRefMsg", ast.ident);
39 let _handler_ident = format_ident!("{}Handler", ast.ident);
40 let schema_ident = format_ident!("{}Schema", ast.ident);
41 let caller_ident = format_ident!("{}Caller", ast.ident);
42 let ref_lifetime = Lifetime::new("'a", Span::call_site());
43 let (variants, ref_variants, reply_variants) = ast
44 .items
45 .iter()
46 .filter_map(|n| match n {
47 syn::TraitItem::Fn(n) => Some(n),
48 _ => return None,
49 })
50 .map(|n| {
51 let no_async = n.sig.asyncness.is_none();
52 // let found_input_trans_stream = get_input_trans_stream(n);
53 let ident = format_ident!("{}", n.sig.ident.to_string().to_case(Case::UpperCamel));
54 // match found_input_trans_stream {
55 // None => {
56 let (named_fields, named_ref_fields) = n
57 .sig
58 .inputs
59 .iter()
60 .filter_map(|n| match n {
61 FnArg::Typed(n) => Some(n),
62 _ => None,
63 })
64 .map(|n| {
65 (
66 Field {
67 attrs: vec![],
68 vis: Visibility::Inherited,
69 mutability: FieldMutability::None,
70 ident: match n.pat.as_ref() {
71 syn::Pat::Ident(n) => Some(n.ident.clone()),
72 _ => panic!("fm parameters pat: only support named fields"),
73 },
74 colon_token: Default::default(),
75 ty: *n.ty.clone(),
76 },
77 Field {
78 attrs: vec![],
79 vis: Visibility::Inherited,
80 mutability: FieldMutability::None,
81 ident: match n.pat.as_ref() {
82 syn::Pat::Ident(n) => Some(n.ident.clone()),
83 _ => panic!("fm parameters pat: only support named fields"),
84 },
85 colon_token: Default::default(),
86 ty: Type::Reference(TypeReference {
87 and_token: Default::default(),
88 lifetime: Some(ref_lifetime.clone()),
89 mutability: None,
90 elem: Box::new(*n.ty.clone()),
91 }),
92 },
93 )
94 })
95 .collect();
96 let base = Variant {
97 attrs: vec![],
98 ident: ident.clone(),
99 fields: Fields::Named(FieldsNamed {
100 brace_token: Default::default(),
101 named: named_fields,
102 }),
103 discriminant: None,
104 };
105 let base_ref = Variant {
106 attrs: vec![],
107 ident: ident.clone(),
108 fields: Fields::Named(FieldsNamed {
109 brace_token: Default::default(),
110 named: named_ref_fields,
111 }),
112 discriminant: None,
113 };
114 let base_reply = Variant {
115 attrs: vec![],
116 ident: ident.clone(),
117 fields: Fields::Unnamed({
118 let ty = match &n.sig.output {
119 ReturnType::Default => parse_quote!(()),
120 ReturnType::Type(_, ty) => *ty.clone(),
121 };
122 let ty = get_future_output(no_async, &ty);
123 FieldsUnnamed {
124 paren_token: Default::default(),
125 unnamed: once(Field {
126 attrs: vec![],
127 vis: Visibility::Inherited,
128 mutability: FieldMutability::None,
129 ident: None,
130 colon_token: None,
131 ty,
132 })
133 .collect(),
134 }
135 }),
136 discriminant: None,
137 };
138 (base, base_ref, base_reply)
139 // }
140 // Some((_found_input_trans_stream, found_input_trans_stream_index)) => {
141 // let start_variant = Variant {
142 // attrs: vec![],
143 // ident: ident.clone(),
144 // fields: Fields::Named(FieldsNamed {
145 // brace_token: Default::default(),
146 // named: n
147 // .sig
148 // .inputs
149 // .iter()
150 // .enumerate()
151 // .filter(|n| n.0 != found_input_trans_stream_index)
152 // .filter_map(|(_, n)| match n {
153 // FnArg::Typed(n) => Some(n),
154 // _ => None,
155 // })
156 // .map(|n| Field {
157 // attrs: vec![],
158 // vis: Visibility::Inherited,
159 // mutability: FieldMutability::None,
160 // ident: match n.pat.as_ref() {
161 // syn::Pat::Ident(n) => Some(n.ident.clone()),
162 // _ => panic!("fm parameters pat: only support named fields"),
163 // },
164 // colon_token: Default::default(),
165 // ty: *n.ty.clone(),
166 // })
167 // .collect(),
168 // }),
169 // discriminant: None,
170 // };
171 // let base_reply = Variant {
172 // attrs: vec![],
173 // ident: ident.clone(),
174 // fields: Fields::Unnamed({
175 // let ty = match &n.sig.output {
176 // ReturnType::Default => parse_quote!(()),
177 // ReturnType::Type(_, ty) => *ty.clone(),
178 // };
179 // let ty = if no_async {
180 // let future_output_type = match &ty {
181 // Type::ImplTrait(type_impl) => {
182 // type_impl.bounds.iter().find_map(|n| match n {
183 // TypeParamBound::Trait(t) => {
184 // let x = t
185 // .path
186 // .segments
187 // .iter()
188 // .find(|n| n.ident == "Future");
189 // if let Some(x) = x {
190 // let PathArguments::AngleBracketed(args) =
191 // &x.arguments
192 // else {
193 // panic!("invalid return type")
194 // };
195 // args.args.iter().find_map(|n| match n {
196 // GenericArgument::AssocType(a) => {
197 // if a.ident == "Output" {
198 // Some(a.ty.clone())
199 // } else {
200 // None
201 // }
202 // }
203 // _ => None,
204 // })
205 // } else {
206 // None
207 // }
208 // }
209 // _ => None,
210 // })
211 // }
212 // _ => None,
213 // };
214 // if let Some(rt) = future_output_type {
215 // rt
216 // } else {
217 // parse_quote! {
218 // <#ty as core::future::Future>::Output
219 // }
220 // }
221 // } else {
222 // ty
223 // };
224 // FieldsUnnamed {
225 // paren_token: Default::default(),
226 // unnamed: once(Field {
227 // attrs: vec![],
228 // vis: Visibility::Inherited,
229 // mutability: FieldMutability::None,
230 // ident: None,
231 // colon_token: None,
232 // ty,
233 // })
234 // .collect(),
235 // }
236 // }),
237 // discriminant: None,
238 // };
239 // Either::Right(vec![(start_variant, base_reply)].into_iter())
240 // }
241 // }
242 })
243 .collect();
244 let msg_enum = ItemEnum {
245 attrs: parse_quote!(#[derive(Debug,serde::Serialize, serde::Deserialize)]),
246 vis: vis.clone(),
247 enum_token: Default::default(),
248 ident: msg_enum_ident.clone(),
249 generics: Default::default(),
250 brace_token: Default::default(),
251 variants,
252 };
253 let msg_ref_enum = ItemEnum {
254 attrs: parse_quote!(#[derive(Debug,serde::Serialize)]),
255 vis: vis.clone(),
256 enum_token: Default::default(),
257 ident: msg_ref_enum_ident.clone(),
258 generics: Generics {
259 lt_token: None,
260 params: once(GenericParam::Lifetime(LifetimeParam {
261 attrs: vec![],
262 lifetime: ref_lifetime.clone(),
263 colon_token: None,
264 bounds: Default::default(),
265 }))
266 .collect(),
267 gt_token: None,
268 where_clause: None,
269 },
270 brace_token: Default::default(),
271 variants: ref_variants,
272 };
273 let msg_reply_enum = ItemEnum {
274 attrs: parse_quote!(#[derive(Debug,serde::Serialize, serde::Deserialize)]),
275 vis: vis.clone(),
276 enum_token: Default::default(),
277 ident: msg_reply_enum_ident.clone(),
278 generics: Default::default(),
279 brace_token: Default::default(),
280 variants: reply_variants,
281 };
282
283 let (rpc_call_fn, rpc_call_fn_impl, match_expr, msg_info_matches, msg_reply_info_matches, ref_msg_info_matches): (Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>) = ast
284 .items
285 .iter()
286 .filter_map(|n| match n {
287 syn::TraitItem::Fn(n) => Some(n),
288 _ => return None,
289 })
290 .enumerate()
291 .map(|(i, n)| {
292 let id = i + 1;
293 let fields: Punctuated<Ident, Token![,]> = n
294 .sig
295 .inputs
296 .iter()
297 .filter_map(|n| match n {
298 FnArg::Typed(n) => Some(match n.pat.as_ref() {
299 syn::Pat::Ident(n) => n.ident.clone(),
300 _ => panic!("only support named fields"),
301 }),
302 _ => None,
303 })
304 .collect();
305 let mut rpc_call_fn = n.sig.clone();
306 for arg in rpc_call_fn.inputs.iter_mut() {
307 if let FnArg::Typed(arg) = arg {
308 arg.ty = Box::new(Type::Reference(TypeReference {
309 and_token: Default::default(),
310 lifetime: Some(ref_lifetime.clone()),
311 mutability: None,
312 elem: Box::new(*arg.ty.clone()),
313 }));
314 };
315 }
316 rpc_call_fn.generics = Generics {
317 lt_token: Some(Default::default()),
318 params: once(GenericParam::Lifetime(LifetimeParam {
319 attrs: vec![],
320 lifetime: ref_lifetime.clone(),
321 colon_token: None,
322 bounds: Default::default(),
323 })).collect(),
324 gt_token: Some(Default::default()),
325 where_clause: None,
326 };
327 match rpc_call_fn.output {
328 ReturnType::Default => {
329 rpc_call_fn = parse_quote! {
330 -> impl core::future::Future<Output = Result<(), xy_rpc::RpcError>> + xy_rpc::maybe_send::MaybeSend +'static
331 }
332 }
333 ReturnType::Type(_, ty) => {
334 let is_async = rpc_call_fn.asyncness.is_some();
335 let output_ty = get_future_output(!is_async, &*ty);
336 if is_async {
337 rpc_call_fn.asyncness = None;
338 }
339 rpc_call_fn.output = parse_quote! {
340 -> impl core::future::Future<Output = Result<#output_ty, xy_rpc::RpcError>> + xy_rpc::maybe_send::MaybeSend +'static
341 };
342 }
343 }
344 let fn_ident = &rpc_call_fn.ident;
345 let item_name = fn_ident.to_string().to_case(Case::UpperCamel);
346 let enum_item_ident = format_ident!("{}", item_name);
347 (
348 quote! {
349 #rpc_call_fn
350 },
351 quote! {
352 #rpc_call_fn {
353 let future = self.call(#msg_ref_enum_ident::#enum_item_ident { #fields });
354 async move {
355 let reply = future.await?;
356 let #msg_reply_enum_ident::#enum_item_ident(reply) = reply.msg else {
357 return Err(xy_rpc::RpcError::InvalidMsg)
358 };
359 Ok(reply)
360 }
361 }
362 },
363 quote! {
364 #msg_enum_ident::#enum_item_ident { #fields } => {
365 let r = self.service.#fn_ident(#fields).await;
366 #msg_reply_enum_ident::#enum_item_ident(r)
367 }
368 },
369 quote! {
370 #msg_enum_ident::#enum_item_ident { .. } => {
371 xy_rpc::RpcMsgInfo {
372 id: #id as _,
373 name: #item_name
374 }
375 }
376 },
377 quote! {
378 #msg_reply_enum_ident::#enum_item_ident(_) => {
379 xy_rpc::RpcMsgInfo {
380 id: #id as _,
381 name: #item_name
382 }
383 }
384 },
385 quote! {
386 #msg_ref_enum_ident::#enum_item_ident { .. } => {
387 xy_rpc::RpcMsgInfo {
388 id: #id as _,
389 name: #item_name
390 }
391 }
392 },
393 )
394 })
395 .collect();
396
397 let schema = quote! {
398 #[derive(Clone, Debug, Default)]
399 #vis struct #schema_ident;
400 impl xy_rpc::RpcServiceSchema for #schema_ident
401 {
402 type Msg = #msg_enum_ident;
403 type Reply = #msg_reply_enum_ident;
404 }
405 };
406
407 let impls = quote! {
408 impl<T> xy_rpc::RpcMsgHandler<#schema_ident> for xy_rpc::RpcMsgHandlerWrapper<T>
409 where
410 T: #trait_ident,
411 {
412 fn handle(
413 &self,
414 msg: #msg_enum_ident,
415 ) -> impl core::future::Future<Output = #msg_reply_enum_ident> + xy_rpc::maybe_send::MaybeSend {
416 async move {
417 match msg {
418 #(#match_expr)*
419 }
420 }
421 }
422 }
423 impl<'a> xy_rpc::RpcRefMsg for #msg_ref_enum_ident<'a> {
424 fn info(&self) -> xy_rpc::RpcMsgInfo {
425 match self {
426 #(#ref_msg_info_matches)*
427 }
428 }
429 }
430 impl xy_rpc::RpcRefMsg for #msg_enum_ident {
431 fn info(&self) -> xy_rpc::RpcMsgInfo {
432 match self {
433 #(#msg_info_matches)*
434 }
435 }
436 }
437 impl xy_rpc::RpcMsg for #msg_enum_ident {
438 type Ref<'a> = #msg_ref_enum_ident<'a>;
439 }
440 impl<'a> xy_rpc::RpcRefMsg for &'a #msg_reply_enum_ident {
441 fn info(&self) -> xy_rpc::RpcMsgInfo {
442 match self {
443 #(#msg_reply_info_matches)*
444 }
445 }
446 }
447 impl xy_rpc::RpcRefMsg for #msg_reply_enum_ident {
448 fn info(&self) -> xy_rpc::RpcMsgInfo {
449 match self {
450 #(#msg_reply_info_matches)*
451 }
452 }
453 }
454 impl xy_rpc::RpcMsg for #msg_reply_enum_ident {
455 type Ref<'a> = &'a #msg_reply_enum_ident;
456 }
457 #vis trait #caller_ident {
458 #(#rpc_call_fn;)*
459 }
460 impl<CF> #caller_ident for xy_rpc::XyRpcChannel<CF,#schema_ident> where CF: xy_rpc::formats::SerdeFormat {
461 #(#rpc_call_fn_impl)*
462 }
463 };
464
465 quote! {
466 #ast
467 #schema
468 #msg_enum
469 #msg_ref_enum
470 #msg_reply_enum
471 #impls
472 }
473 .into()
474}
475
476fn get_future_output(no_async: bool, ty: &Type) -> Type {
477 if no_async {
478 let future_output_type = match &ty {
479 Type::ImplTrait(type_impl) => type_impl.bounds.iter().find_map(|n| match n {
480 TypeParamBound::Trait(t) => {
481 let x = t.path.segments.iter().find(|n| n.ident == "Future");
482 if let Some(x) = x {
483 let PathArguments::AngleBracketed(args) = &x.arguments else {
484 panic!("invalid return type")
485 };
486 args.args.iter().find_map(|n| match n {
487 GenericArgument::AssocType(a) => {
488 if a.ident == "Output" {
489 Some(a.ty.clone())
490 } else {
491 None
492 }
493 }
494 _ => None,
495 })
496 } else {
497 None
498 }
499 }
500 _ => None,
501 }),
502 _ => None,
503 };
504 if let Some(rt) = future_output_type {
505 rt
506 } else {
507 parse_quote! {
508 <#ty as core::future::Future>::Output
509 }
510 }
511 } else {
512 ty.clone()
513 }
514}
515/*
516fn get_input_trans_stream(n: &TraitItemFn) -> Option<(&PathSegment, usize)> {
517 n.sig.inputs.iter().enumerate().find_map(|(i, n)| {
518 let ty = match n {
519 FnArg::Typed(n) => n,
520 _ => return None,
521 };
522 let Type::Path(ty) = &*ty.ty else { return None };
523 let segment = ty.path.segments.last()?;
524 (segment.ident == "TransStream").then_some((segment, i))
525 })
526}
527*/