Skip to main content

sactor_macros/
lib.rs

1use manyhow::manyhow;
2use proc_macro2::TokenStream;
3use quote::quote;
4use syn::{
5    Error, FnArg, GenericParam, Ident, ImplItem, ImplItemFn, ItemImpl, Pat, PatIdent, Result,
6    ReturnType, Type, Visibility, parse2, spanned::Spanned,
7};
8
9#[manyhow]
10#[proc_macro_attribute]
11pub fn sactor(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
12    let handle_vis: Visibility = if attr.is_empty() {
13        Visibility::Inherited
14    } else {
15        parse2(attr)?
16    };
17    let mut input: ItemImpl = parse2(item)?;
18    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
19
20    let self_ident = {
21        let Type::Path(path) = input.self_ty.as_ref() else {
22            return Err(Error::new_spanned(&input.self_ty, "expected a path"));
23        };
24        path.path.segments.last().unwrap().ident.clone()
25    };
26    let handle_ident = Ident::new(&format!("{}Handle", self_ident), self_ident.span());
27    let events_ident = Ident::new(&format!("{}Events", self_ident), self_ident.span());
28
29    let type_params: Vec<_> = input
30        .generics
31        .params
32        .iter()
33        .filter_map(|p| {
34            if let GenericParam::Type(tp) = p {
35                Some(&tp.ident)
36            } else {
37                None
38            }
39        })
40        .collect();
41
42    let mut event_variants = Vec::new();
43    let mut handle_items = Vec::new();
44    let mut run_arms = Vec::new();
45    let mut sel = None; // select ident and asyncness
46    let mut error_handler = None;
47    for item in &mut input.items {
48        let ImplItem::Fn(ImplItemFn {
49            attrs, vis, sig, ..
50        }) = item
51        else {
52            continue;
53        };
54        if sig.inputs.is_empty() {
55            continue;
56        }
57        match sig.inputs.first().unwrap() {
58            FnArg::Typed(_) => continue,
59            FnArg::Receiver(receiver) if receiver.reference.is_none() => continue,
60            _ => {}
61        }
62
63        let mut skip = false;
64        let mut reply = None;
65        let mut select = false;
66        let mut error = false;
67        attrs.retain(|attr| {
68            let path = attr.meta.path();
69            if path.is_ident("skip") {
70                skip = true;
71                return false;
72            }
73            if path.is_ident("reply") {
74                reply = Some(true);
75                return false;
76            }
77            if path.is_ident("no_reply") {
78                reply = Some(false);
79                return false;
80            }
81            if path.is_ident("select") {
82                select = true;
83                return false;
84            }
85            if path.is_ident("handle_error") {
86                error = true;
87                return false;
88            }
89            true
90        });
91        if select {
92            if sel.is_some() {
93                return Err(Error::new_spanned(
94                    &sig.ident,
95                    "multiple select methods are not allowed",
96                ));
97            }
98            sel = Some((sig.ident.clone(), sig.asyncness.is_some()));
99            continue;
100        }
101        if error {
102            if error_handler.is_some() {
103                return Err(Error::new_spanned(
104                    &sig.ident,
105                    "multiple error handler methods are not allowed",
106                ));
107            }
108            error_handler = Some((sig.ident.clone(), sig.asyncness.is_some()));
109            continue;
110        }
111        if skip {
112            continue;
113        }
114
115        // reject method-level generics
116        if !sig.generics.params.is_empty() {
117            return Err(Error::new_spanned(
118                &sig.generics,
119                "should not have method-level generics",
120            ));
121        }
122
123        // output type
124        let mut handle_error = false;
125        let output = match &sig.output {
126            ReturnType::Default => quote! { () },
127            ReturnType::Type(_, ty) => {
128                if reply.is_none() {
129                    reply = Some(true);
130                }
131                if let Type::Path(path) = ty.as_ref() {
132                    let Some(last) = path.path.segments.last() else {
133                        return Err(Error::new_spanned(
134                            &path.path,
135                            "expected a path with segments",
136                        ));
137                    };
138                    if last.ident == "Result" {
139                        handle_error = true;
140                    }
141                }
142                if let Some(false) = reply {
143                    quote! { () }
144                } else {
145                    quote! { #ty }
146                }
147            }
148        };
149        let mut handle_sig = sig.clone();
150        handle_sig.asyncness = Some(parse2(quote! { async })?);
151        handle_sig.output = parse2(quote! { -> anyhow::Result<#output> })?;
152
153        // input args
154        let mut arg_types = Vec::new();
155        let mut arg_names = Vec::new();
156        for (i, arg) in &mut handle_sig.inputs.iter_mut().enumerate() {
157            let arg = match arg {
158                FnArg::Typed(arg) => arg,
159                FnArg::Receiver(arg) => {
160                    arg.mutability = None;
161                    let Type::Reference(reference) = arg.ty.as_mut() else {
162                        return Err(Error::new_spanned(&arg.ty, "expected a reference"));
163                    };
164                    reference.mutability = None;
165                    continue;
166                }
167            };
168            arg_types.push(arg.ty.clone());
169            let arg_name = format!("arg{}", i);
170            arg_names.push(Ident::new(&arg_name, arg.pat.span()));
171            *arg.pat = Pat::Ident(PatIdent {
172                attrs: Vec::new(),
173                by_ref: None,
174                mutability: None,
175                ident: Ident::new(&arg_name, arg.pat.span()),
176                subpat: None,
177            });
178        }
179
180        // event type and args
181        let event_name = &sig.ident;
182        let arg_typle_type = quote! { (#(#arg_types),*) };
183        let arg_tuple = quote! { (#(#arg_names),*) };
184
185        let f = if reply.unwrap_or(false) {
186            quote! {
187                #vis #handle_sig {
188                    let (tx, rx) = futures::channel::oneshot::channel();
189                    self.0.unbounded_send(#events_ident::#event_name(#arg_tuple, tx))
190                        .map_err(|_| sactor::error::SactorError::ActorStopped)?;
191                    #[allow(clippy::needless_question_mark)]
192                    Ok(rx.await.map_err(|_| sactor::error::SactorError::ActorStopped)?)
193                }
194            }
195        } else {
196            quote! {
197                #vis #handle_sig {
198                    self.0.unbounded_send(#events_ident::#event_name(#arg_tuple))
199                        .map_err(|_| sactor::error::SactorError::ActorStopped)?;
200                    Ok(())
201                }
202            }
203        };
204
205        handle_items.push(f);
206
207        let aw = match sig.asyncness {
208            None => quote! {},
209            Some(_) => quote! { .await },
210        };
211        let handle_error = match handle_error {
212            false => quote! {},
213            true => quote! {
214                if let Err(e) = &mut result {
215                    actor.__sactor_handle_error(e).await;
216                }
217            },
218        };
219        if reply.unwrap_or(false) {
220            event_variants.push(
221                quote! { #event_name(#arg_typle_type, futures::channel::oneshot::Sender<#output>) },
222            );
223            run_arms.push(quote! {
224                Ok(#events_ident::#event_name(#arg_tuple, tx)) => {
225                    let mut result = actor.#event_name #arg_tuple #aw;
226                    #handle_error;
227                    let _ = tx.send(result);
228                }
229            });
230        } else {
231            event_variants.push(quote! { #event_name(#arg_typle_type) });
232            run_arms.push(quote! {
233                Ok(#events_ident::#event_name(#arg_tuple)) => {
234                    let mut result = actor.#event_name #arg_tuple #aw;
235                    #handle_error;
236                }
237            });
238        }
239    }
240
241    let select = match sel {
242        None => quote! {
243            let sel = std::future::pending::<(#events_ident #ty_generics, usize, Vec<Selection>)>();
244        },
245        Some((sel, false)) => quote! {
246            let futures: Vec<Selection> = actor.#sel();
247            let sel = futures::future::select_all(futures);
248        },
249        Some((sel, true)) => quote! {
250            let futures: Vec<Selection> = actor.#sel().await;
251            let sel = futures::future::select_all(futures);
252        },
253    };
254
255    input.items.push(parse2(quote! {
256        fn run<F>(init: F) -> (impl Future<Output = ()>, #handle_ident #ty_generics)
257        where
258            F: FnOnce(#handle_ident #ty_generics) -> Self,
259        {
260            use futures::FutureExt as _;
261            let (tx, mut rx) = futures::channel::mpsc::unbounded();
262            let handle = #handle_ident(tx);
263            let mut actor = init(handle.clone());
264            let handle2 = handle.clone();
265            let future = async move {
266                loop {
267                    #select
268                    futures::select_biased! {
269                        event = rx.recv() => {
270                            match event {
271                                #(#run_arms),*
272                                Ok(#events_ident::__sactor_stop) | Err(_) => break,
273                                Ok(#events_ident::__sactor_phantom(_)) => unreachable!(),
274                            }
275                        }
276                        event = async { sel.await.0 }.fuse() => {
277                            handle2.0.unbounded_send(event).unwrap();
278                        }
279                    }
280                }
281            };
282            (future, handle)
283        }
284    })?);
285
286    let call_error_handler = match error_handler {
287        None => quote! {},
288        Some((error_handler, false)) => quote! {
289            self.#error_handler(error);
290        },
291        Some((error_handler, true)) => quote! {
292            self.#error_handler(error).await;
293        },
294    };
295    input.items.push(parse2(quote! {
296        async fn __sactor_handle_error(&mut self, error: &mut anyhow::Error) {
297            #call_error_handler
298        }
299    })?);
300
301    Ok(quote! {
302        type Selection<'a> = std::pin::Pin<Box<dyn Future<Output = #events_ident #ty_generics> + Send + 'a>>;
303
304        #[allow(unused_macros)]
305        macro_rules! selection {
306            ($expression:expr, $variant:ident) => {
307                Box::pin(async { $expression; #events_ident::$variant(()) }) as Selection
308            };
309            ($expression:expr, $variant:ident, $name:pat => $($arg:tt)*) => {
310                Box::pin(async { let $name = $expression; #events_ident::$variant($($arg)*) }) as Selection
311            };
312        }
313
314        #input
315
316        #[allow(non_camel_case_types)]
317        enum #events_ident #impl_generics #where_clause {
318            __sactor_stop,
319            __sactor_phantom(std::marker::PhantomData<(#(#type_params),*)>),
320            #(#event_variants),*
321        }
322
323        #[derive(Clone)]
324        #handle_vis struct #handle_ident #impl_generics #where_clause (futures::channel::mpsc::UnboundedSender<#events_ident #ty_generics>);
325        impl #impl_generics #handle_ident #ty_generics #where_clause {
326            #(#handle_items)*
327
328            #handle_vis fn is_running(&self) -> bool {
329                !self.0.is_closed()
330            }
331
332            #handle_vis fn stop(&self) {
333                let _ = self.0.unbounded_send(#events_ident::__sactor_stop);
334            }
335        }
336    })
337}