1use manyhow::manyhow;
2use proc_macro2::TokenStream;
3use quote::quote;
4use syn::{
5 Error, FnArg, GenericParam, Ident, ImplItem, ImplItemFn, ItemImpl, Pat, PatIdent, Result,
6 ReturnType, Type, Visibility, parse2, spanned::Spanned,
7};
8
9#[manyhow]
10#[proc_macro_attribute]
11pub fn sactor(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
12 let handle_vis: Visibility = if attr.is_empty() {
13 Visibility::Inherited
14 } else {
15 parse2(attr)?
16 };
17 let mut input: ItemImpl = parse2(item)?;
18 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
19
20 let self_ident = {
21 let Type::Path(path) = input.self_ty.as_ref() else {
22 return Err(Error::new_spanned(&input.self_ty, "expected a path"));
23 };
24 path.path.segments.last().unwrap().ident.clone()
25 };
26 let handle_ident = Ident::new(&format!("{}Handle", self_ident), self_ident.span());
27 let events_ident = Ident::new(&format!("{}Events", self_ident), self_ident.span());
28
29 let type_params: Vec<_> = input
30 .generics
31 .params
32 .iter()
33 .filter_map(|p| {
34 if let GenericParam::Type(tp) = p {
35 Some(&tp.ident)
36 } else {
37 None
38 }
39 })
40 .collect();
41
42 let mut event_variants = Vec::new();
43 let mut handle_items = Vec::new();
44 let mut run_arms = Vec::new();
45 let mut sel = None; let mut error_handler = None;
47 for item in &mut input.items {
48 let ImplItem::Fn(ImplItemFn {
49 attrs, vis, sig, ..
50 }) = item
51 else {
52 continue;
53 };
54 if sig.inputs.is_empty() {
55 continue;
56 }
57 match sig.inputs.first().unwrap() {
58 FnArg::Typed(_) => continue,
59 FnArg::Receiver(receiver) if receiver.reference.is_none() => continue,
60 _ => {}
61 }
62
63 let mut skip = false;
64 let mut reply = None;
65 let mut select = false;
66 let mut error = false;
67 attrs.retain(|attr| {
68 let path = attr.meta.path();
69 if path.is_ident("skip") {
70 skip = true;
71 return false;
72 }
73 if path.is_ident("reply") {
74 reply = Some(true);
75 return false;
76 }
77 if path.is_ident("no_reply") {
78 reply = Some(false);
79 return false;
80 }
81 if path.is_ident("select") {
82 select = true;
83 return false;
84 }
85 if path.is_ident("handle_error") {
86 error = true;
87 return false;
88 }
89 true
90 });
91 if select {
92 if sel.is_some() {
93 return Err(Error::new_spanned(
94 &sig.ident,
95 "multiple select methods are not allowed",
96 ));
97 }
98 sel = Some((sig.ident.clone(), sig.asyncness.is_some()));
99 continue;
100 }
101 if error {
102 if error_handler.is_some() {
103 return Err(Error::new_spanned(
104 &sig.ident,
105 "multiple error handler methods are not allowed",
106 ));
107 }
108 error_handler = Some((sig.ident.clone(), sig.asyncness.is_some()));
109 continue;
110 }
111 if skip {
112 continue;
113 }
114
115 if !sig.generics.params.is_empty() {
117 return Err(Error::new_spanned(
118 &sig.generics,
119 "should not have method-level generics",
120 ));
121 }
122
123 let mut handle_error = false;
125 let output = match &sig.output {
126 ReturnType::Default => quote! { () },
127 ReturnType::Type(_, ty) => {
128 if reply.is_none() {
129 reply = Some(true);
130 }
131 if let Type::Path(path) = ty.as_ref() {
132 let Some(last) = path.path.segments.last() else {
133 return Err(Error::new_spanned(
134 &path.path,
135 "expected a path with segments",
136 ));
137 };
138 if last.ident == "Result" {
139 handle_error = true;
140 }
141 }
142 if let Some(false) = reply {
143 quote! { () }
144 } else {
145 quote! { #ty }
146 }
147 }
148 };
149 let mut handle_sig = sig.clone();
150 handle_sig.asyncness = Some(parse2(quote! { async })?);
151 handle_sig.output = parse2(quote! { -> anyhow::Result<#output> })?;
152
153 let mut arg_types = Vec::new();
155 let mut arg_names = Vec::new();
156 for (i, arg) in &mut handle_sig.inputs.iter_mut().enumerate() {
157 let arg = match arg {
158 FnArg::Typed(arg) => arg,
159 FnArg::Receiver(arg) => {
160 arg.mutability = None;
161 let Type::Reference(reference) = arg.ty.as_mut() else {
162 return Err(Error::new_spanned(&arg.ty, "expected a reference"));
163 };
164 reference.mutability = None;
165 continue;
166 }
167 };
168 arg_types.push(arg.ty.clone());
169 let arg_name = format!("arg{}", i);
170 arg_names.push(Ident::new(&arg_name, arg.pat.span()));
171 *arg.pat = Pat::Ident(PatIdent {
172 attrs: Vec::new(),
173 by_ref: None,
174 mutability: None,
175 ident: Ident::new(&arg_name, arg.pat.span()),
176 subpat: None,
177 });
178 }
179
180 let event_name = &sig.ident;
182 let arg_typle_type = quote! { (#(#arg_types),*) };
183 let arg_tuple = quote! { (#(#arg_names),*) };
184
185 let f = if reply.unwrap_or(false) {
186 quote! {
187 #vis #handle_sig {
188 let (tx, rx) = futures::channel::oneshot::channel();
189 self.0.unbounded_send(#events_ident::#event_name(#arg_tuple, tx))
190 .map_err(|_| sactor::error::SactorError::ActorStopped)?;
191 #[allow(clippy::needless_question_mark)]
192 Ok(rx.await.map_err(|_| sactor::error::SactorError::ActorStopped)?)
193 }
194 }
195 } else {
196 quote! {
197 #vis #handle_sig {
198 self.0.unbounded_send(#events_ident::#event_name(#arg_tuple))
199 .map_err(|_| sactor::error::SactorError::ActorStopped)?;
200 Ok(())
201 }
202 }
203 };
204
205 handle_items.push(f);
206
207 let aw = match sig.asyncness {
208 None => quote! {},
209 Some(_) => quote! { .await },
210 };
211 let handle_error = match handle_error {
212 false => quote! {},
213 true => quote! {
214 if let Err(e) = &mut result {
215 actor.__sactor_handle_error(e).await;
216 }
217 },
218 };
219 if reply.unwrap_or(false) {
220 event_variants.push(
221 quote! { #event_name(#arg_typle_type, futures::channel::oneshot::Sender<#output>) },
222 );
223 run_arms.push(quote! {
224 Ok(#events_ident::#event_name(#arg_tuple, tx)) => {
225 let mut result = actor.#event_name #arg_tuple #aw;
226 #handle_error;
227 let _ = tx.send(result);
228 }
229 });
230 } else {
231 event_variants.push(quote! { #event_name(#arg_typle_type) });
232 run_arms.push(quote! {
233 Ok(#events_ident::#event_name(#arg_tuple)) => {
234 let mut result = actor.#event_name #arg_tuple #aw;
235 #handle_error;
236 }
237 });
238 }
239 }
240
241 let select = match sel {
242 None => quote! {
243 let sel = std::future::pending::<(#events_ident #ty_generics, usize, Vec<Selection>)>();
244 },
245 Some((sel, false)) => quote! {
246 let futures: Vec<Selection> = actor.#sel();
247 let sel = futures::future::select_all(futures);
248 },
249 Some((sel, true)) => quote! {
250 let futures: Vec<Selection> = actor.#sel().await;
251 let sel = futures::future::select_all(futures);
252 },
253 };
254
255 input.items.push(parse2(quote! {
256 fn run<F>(init: F) -> (impl Future<Output = ()>, #handle_ident #ty_generics)
257 where
258 F: FnOnce(#handle_ident #ty_generics) -> Self,
259 {
260 use futures::FutureExt as _;
261 let (tx, mut rx) = futures::channel::mpsc::unbounded();
262 let handle = #handle_ident(tx);
263 let mut actor = init(handle.clone());
264 let handle2 = handle.clone();
265 let future = async move {
266 loop {
267 #select
268 futures::select_biased! {
269 event = rx.recv() => {
270 match event {
271 #(#run_arms),*
272 Ok(#events_ident::__sactor_stop) | Err(_) => break,
273 Ok(#events_ident::__sactor_phantom(_)) => unreachable!(),
274 }
275 }
276 event = async { sel.await.0 }.fuse() => {
277 handle2.0.unbounded_send(event).unwrap();
278 }
279 }
280 }
281 };
282 (future, handle)
283 }
284 })?);
285
286 let call_error_handler = match error_handler {
287 None => quote! {},
288 Some((error_handler, false)) => quote! {
289 self.#error_handler(error);
290 },
291 Some((error_handler, true)) => quote! {
292 self.#error_handler(error).await;
293 },
294 };
295 input.items.push(parse2(quote! {
296 async fn __sactor_handle_error(&mut self, error: &mut anyhow::Error) {
297 #call_error_handler
298 }
299 })?);
300
301 Ok(quote! {
302 type Selection<'a> = std::pin::Pin<Box<dyn Future<Output = #events_ident #ty_generics> + Send + 'a>>;
303
304 #[allow(unused_macros)]
305 macro_rules! selection {
306 ($expression:expr, $variant:ident) => {
307 Box::pin(async { $expression; #events_ident::$variant(()) }) as Selection
308 };
309 ($expression:expr, $variant:ident, $name:pat => $($arg:tt)*) => {
310 Box::pin(async { let $name = $expression; #events_ident::$variant($($arg)*) }) as Selection
311 };
312 }
313
314 #input
315
316 #[allow(non_camel_case_types)]
317 enum #events_ident #impl_generics #where_clause {
318 __sactor_stop,
319 __sactor_phantom(std::marker::PhantomData<(#(#type_params),*)>),
320 #(#event_variants),*
321 }
322
323 #[derive(Clone)]
324 #handle_vis struct #handle_ident #impl_generics #where_clause (futures::channel::mpsc::UnboundedSender<#events_ident #ty_generics>);
325 impl #impl_generics #handle_ident #ty_generics #where_clause {
326 #(#handle_items)*
327
328 #handle_vis fn is_running(&self) -> bool {
329 !self.0.is_closed()
330 }
331
332 #handle_vis fn stop(&self) {
333 let _ = self.0.unbounded_send(#events_ident::__sactor_stop);
334 }
335 }
336 })
337}