1#![deny(missing_docs)]
2use heck::ToSnakeCase;
7use proc_macro::TokenStream;
8use proc_macro2::TokenStream as TokenStream2;
9use quote::{format_ident, quote};
10use syn::{Field, Fields, FieldsNamed, Ident, ItemEnum, Type};
11
12struct MacroArgs {
13 actor: Ident,
14}
15
16impl syn::parse::Parse for MacroArgs {
17 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
18 Ok(MacroArgs {
19 actor: input.parse()?,
20 })
21 }
22}
23
24fn parse_call_attr(attr: &syn::Attribute) -> syn::Result<Option<Type>> {
27 match &attr.meta {
28 syn::Meta::Path(_) => Ok(None),
29 syn::Meta::List(list) if list.tokens.is_empty() => Ok(None),
30 syn::Meta::List(list) => syn::parse2::<Type>(list.tokens.clone()).map(Some),
31 syn::Meta::NameValue(nv) => Err(syn::Error::new_spanned(
32 nv,
33 "#[call] does not support key=value syntax; use #[call] or #[call(ReturnType)]",
34 )),
35 }
36}
37
38#[proc_macro_attribute]
76pub fn message(attr: TokenStream, item: TokenStream) -> TokenStream {
77 message_impl(attr, item)
78 .unwrap_or_else(|e| e.to_compile_error())
79 .into()
80}
81
82fn message_impl(attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream2> {
83 let MacroArgs { actor } = syn::parse(attr)?;
84 let mut enum_def: ItemEnum = syn::parse(item)?;
85
86 let enum_vis = enum_def.vis.clone();
87 let enum_name = enum_def.ident.clone();
88 let trait_name = format_ident!("{}Ext", actor);
89
90 let mut trait_methods: Vec<TokenStream2> = Vec::new();
91 let mut impl_methods: Vec<TokenStream2> = Vec::new();
92
93 for variant in &mut enum_def.variants {
94 let call_pos = variant.attrs.iter().position(|a| a.path().is_ident("call"));
96 let ret_ty: Option<Type> = call_pos
97 .map(|i| {
98 let attr = variant.attrs.remove(i);
99 parse_call_attr(&attr).map(|t| t.unwrap_or_else(|| syn::parse_quote!(())))
100 })
101 .transpose()?;
102
103 let variant_name = &variant.ident;
104 let method_name = format_ident!("{}", variant_name.to_string().to_snake_case());
105
106 let is_tuple = matches!(variant.fields, Fields::Unnamed(_));
108 let (param_idents, param_types): (Vec<Ident>, Vec<Type>) = match &variant.fields {
109 Fields::Named(f) => f
110 .named
111 .iter()
112 .map(|field| (field.ident.clone().unwrap(), field.ty.clone()))
113 .unzip(),
114 Fields::Unnamed(f) => f
115 .unnamed
116 .iter()
117 .enumerate()
118 .map(|(i, field)| (format_ident!("arg{}", i), field.ty.clone()))
119 .unzip(),
120 Fields::Unit => (vec![], vec![]),
121 };
122
123 let ret = ret_ty.as_ref().map_or_else(|| quote!(()), |t| quote!(#t));
124 let sig = quote! {
125 async fn #method_name(&self #(, #param_idents: #param_types)*)
126 -> ::core::result::Result<#ret, ::stagecraft::ActorDead<()>>
127 };
128
129 let body: TokenStream2 = if let Some(ref ret_ty) = ret_ty {
130 if is_tuple {
131 let respond_to_field = Field {
133 attrs: vec![],
134 vis: syn::Visibility::Inherited,
135 mutability: syn::FieldMutability::None,
136 ident: None,
137 colon_token: None,
138 ty: syn::parse_quote! { ::tokio::sync::oneshot::Sender<#ret_ty> },
139 };
140 match &mut variant.fields {
141 Fields::Unnamed(f) => f.unnamed.push(respond_to_field),
142 _ => unreachable!(),
143 }
144 quote! {
145 self.call(|tx| #enum_name::#variant_name(#(#param_idents,)* tx)).await
146 }
147 } else {
148 let respond_to: Field = syn::parse_quote! {
149 respond_to: ::tokio::sync::oneshot::Sender<#ret_ty>
150 };
151 match &mut variant.fields {
152 Fields::Named(f) => f.named.push(respond_to),
153 Fields::Unit => {
154 variant.fields = Fields::Named(FieldsNamed {
155 brace_token: Default::default(),
156 named: std::iter::once(respond_to).collect(),
157 });
158 }
159 Fields::Unnamed(_) => unreachable!(),
160 }
161 quote! {
162 self.call(|tx| #enum_name::#variant_name { #(#param_idents,)* respond_to: tx }).await
163 }
164 }
165 } else {
166 let construction = if is_tuple {
167 quote! { #enum_name::#variant_name(#(#param_idents),*) }
168 } else if param_idents.is_empty() {
169 quote! { #enum_name::#variant_name }
170 } else {
171 quote! { #enum_name::#variant_name { #(#param_idents),* } }
172 };
173 quote! {
174 self.cast(#construction).await.map_err(|_| ::stagecraft::ActorDead(()))
175 }
176 };
177
178 trait_methods.push(quote! { #sig; });
179 impl_methods.push(quote! { #sig { #body } });
180 }
181
182 Ok(quote! {
183 #enum_def
184
185 #enum_vis trait #trait_name {
186 #(#trait_methods)*
187 }
188
189 impl #trait_name for ::stagecraft::Handle<#actor> {
190 #(#impl_methods)*
191 }
192 })
193}