Skip to main content

stagecraft_macros/
lib.rs

1#![deny(missing_docs)]
2//! Proc-macro crate for stagecraft. Use via the re-exported [`stagecraft::message`] attribute.
3//!
4//! [`stagecraft::message`]: https://docs.rs/stagecraft/latest/stagecraft/attr.message.html
5
6use heck::ToSnakeCase;
7use proc_macro::TokenStream;
8use proc_macro2::TokenStream as TokenStream2;
9use quote::{format_ident, quote};
10use syn::{Field, Fields, FieldsNamed, Ident, ItemEnum, Type};
11
12struct MacroArgs {
13    actor: Ident,
14}
15
16impl syn::parse::Parse for MacroArgs {
17    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
18        Ok(MacroArgs {
19            actor: input.parse()?,
20        })
21    }
22}
23
24/// Returns `None` for `#[call]` or `#[call()]` (→ return `()`),
25/// or `Some(ty)` for `#[call(SomeType)]`.
26fn parse_call_attr(attr: &syn::Attribute) -> syn::Result<Option<Type>> {
27    match &attr.meta {
28        syn::Meta::Path(_) => Ok(None),
29        syn::Meta::List(list) if list.tokens.is_empty() => Ok(None),
30        syn::Meta::List(list) => syn::parse2::<Type>(list.tokens.clone()).map(Some),
31        syn::Meta::NameValue(nv) => Err(syn::Error::new_spanned(
32            nv,
33            "#[call] does not support key=value syntax; use #[call] or #[call(ReturnType)]",
34        )),
35    }
36}
37
38/// Derive actor message convenience methods on a [`Handle`].
39///
40/// Applied to a message enum, this attribute generates an extension trait `{Actor}Ext`
41/// and implements it for `Handle<Actor>`, giving each enum variant a corresponding async method.
42///
43/// # Variant Attributes
44///
45/// - *(none)* — variant sent with [`Handle::cast`] (fire-and-forget); method returns `Result<(), ActorDead<()>>`.
46/// - `#[call]` — variant uses [`Handle::call`]; method returns `Result<(), ActorDead<()>>`.
47/// - `#[call(ReturnType)]` — variant uses [`Handle::call`]; method returns `Result<ReturnType, ActorDead<()>>`.
48///
49/// For `#[call]` variants the macro appends a `respond_to: oneshot::Sender<ReturnType>`
50/// field to the variant. The actor must send a value through it.
51///
52/// # Example
53///
54/// ```rust,ignore
55/// struct MyActor;
56///
57/// #[stagecraft::message(MyActor)]
58/// pub enum MyActorMessage {
59///     Log { text: String },   // cast: fire-and-forget
60///     #[call(u64)]
61///     Count,                   // call: returns u64
62/// }
63///
64/// // The macro generates:
65/// // pub trait MyActorExt {
66/// //     async fn log(&self, text: String) -> Result<(), ActorDead<()>>;
67/// //     async fn count(&self) -> Result<u64, ActorDead<()>>;
68/// // }
69/// // impl MyActorExt for Handle<MyActor> { ... }
70/// ```
71///
72/// [`Handle`]: stagecraft::Handle
73/// [`Handle::cast`]: stagecraft::Handle::cast
74/// [`Handle::call`]: stagecraft::Handle::call
75#[proc_macro_attribute]
76pub fn message(attr: TokenStream, item: TokenStream) -> TokenStream {
77    message_impl(attr, item)
78        .unwrap_or_else(|e| e.to_compile_error())
79        .into()
80}
81
82fn message_impl(attr: TokenStream, item: TokenStream) -> syn::Result<TokenStream2> {
83    let MacroArgs { actor } = syn::parse(attr)?;
84    let mut enum_def: ItemEnum = syn::parse(item)?;
85
86    let enum_vis = enum_def.vis.clone();
87    let enum_name = enum_def.ident.clone();
88    let trait_name = format_ident!("{}Ext", actor);
89
90    let mut trait_methods: Vec<TokenStream2> = Vec::new();
91    let mut impl_methods: Vec<TokenStream2> = Vec::new();
92
93    for variant in &mut enum_def.variants {
94        // Find and remove #[call] attr
95        let call_pos = variant.attrs.iter().position(|a| a.path().is_ident("call"));
96        let ret_ty: Option<Type> = call_pos
97            .map(|i| {
98                let attr = variant.attrs.remove(i);
99                parse_call_attr(&attr).map(|t| t.unwrap_or_else(|| syn::parse_quote!(())))
100            })
101            .transpose()?;
102
103        let variant_name = &variant.ident;
104        let method_name = format_ident!("{}", variant_name.to_string().to_snake_case());
105
106        // Collect param names and types; generate synthetic names for tuple fields.
107        let is_tuple = matches!(variant.fields, Fields::Unnamed(_));
108        let (param_idents, param_types): (Vec<Ident>, Vec<Type>) = match &variant.fields {
109            Fields::Named(f) => f
110                .named
111                .iter()
112                .map(|field| (field.ident.clone().unwrap(), field.ty.clone()))
113                .unzip(),
114            Fields::Unnamed(f) => f
115                .unnamed
116                .iter()
117                .enumerate()
118                .map(|(i, field)| (format_ident!("arg{}", i), field.ty.clone()))
119                .unzip(),
120            Fields::Unit => (vec![], vec![]),
121        };
122
123        let ret = ret_ty.as_ref().map_or_else(|| quote!(()), |t| quote!(#t));
124        let sig = quote! {
125            async fn #method_name(&self #(, #param_idents: #param_types)*)
126                -> ::core::result::Result<#ret, ::stagecraft::ActorDead<()>>
127        };
128
129        let body: TokenStream2 = if let Some(ref ret_ty) = ret_ty {
130            if is_tuple {
131                // Append respond_to as the last tuple element.
132                let respond_to_field = Field {
133                    attrs: vec![],
134                    vis: syn::Visibility::Inherited,
135                    mutability: syn::FieldMutability::None,
136                    ident: None,
137                    colon_token: None,
138                    ty: syn::parse_quote! { ::tokio::sync::oneshot::Sender<#ret_ty> },
139                };
140                match &mut variant.fields {
141                    Fields::Unnamed(f) => f.unnamed.push(respond_to_field),
142                    _ => unreachable!(),
143                }
144                quote! {
145                    self.call(|tx| #enum_name::#variant_name(#(#param_idents,)* tx)).await
146                }
147            } else {
148                let respond_to: Field = syn::parse_quote! {
149                    respond_to: ::tokio::sync::oneshot::Sender<#ret_ty>
150                };
151                match &mut variant.fields {
152                    Fields::Named(f) => f.named.push(respond_to),
153                    Fields::Unit => {
154                        variant.fields = Fields::Named(FieldsNamed {
155                            brace_token: Default::default(),
156                            named: std::iter::once(respond_to).collect(),
157                        });
158                    }
159                    Fields::Unnamed(_) => unreachable!(),
160                }
161                quote! {
162                    self.call(|tx| #enum_name::#variant_name { #(#param_idents,)* respond_to: tx }).await
163                }
164            }
165        } else {
166            let construction = if is_tuple {
167                quote! { #enum_name::#variant_name(#(#param_idents),*) }
168            } else if param_idents.is_empty() {
169                quote! { #enum_name::#variant_name }
170            } else {
171                quote! { #enum_name::#variant_name { #(#param_idents),* } }
172            };
173            quote! {
174                self.cast(#construction).await.map_err(|_| ::stagecraft::ActorDead(()))
175            }
176        };
177
178        trait_methods.push(quote! { #sig; });
179        impl_methods.push(quote! { #sig { #body } });
180    }
181
182    Ok(quote! {
183        #enum_def
184
185        #enum_vis trait #trait_name {
186            #(#trait_methods)*
187        }
188
189        impl #trait_name for ::stagecraft::Handle<#actor> {
190            #(#impl_methods)*
191        }
192    })
193}