1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4 parse::Parse, parse_macro_input, Attribute, FnArg, GenericArgument, Ident, ImplItem,
5 ImplItemFn, ItemImpl, ItemTrait, Pat, PathArguments, ReturnType, TraitItem, Type, TypePath,
6};
7
/// Converts a `PascalCase`/`camelCase` identifier to `snake_case`.
///
/// An underscore is inserted before an uppercase character (other than the
/// very first character) when the previous character is lowercase or the next
/// character is lowercase. Runs of capitals therefore stay together:
/// `HTTPServer` becomes `http_server`.
///
/// Uppercase characters are lowered with the full Unicode mapping so the
/// lowering agrees with the `char::is_uppercase` test that detects them.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    // A few spare bytes for the underscores that get inserted.
    let mut result = String::with_capacity(s.len() + 4);
    for (i, &ch) in chars.iter().enumerate() {
        if !ch.is_uppercase() {
            result.push(ch);
            continue;
        }
        if i > 0 {
            let prev_lower = chars[i - 1].is_lowercase();
            let next_lower = chars.get(i + 1).is_some_and(|c| c.is_lowercase());
            if prev_lower || next_lower {
                result.push('_');
            }
        }
        // `to_lowercase` yields an iterator because some code points lower to
        // more than one `char`.
        result.extend(ch.to_lowercase());
    }
    result
}
32
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// The input is split on `_`. Empty pieces (from leading, trailing or doubled
/// underscores) contribute nothing; for every other piece the first character
/// is uppercased and the remainder is copied unchanged.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::new();
    for word in s.split('_') {
        let mut rest = word.chars();
        // An empty piece has no first character and is skipped.
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
45
/// Returns `name` without a trailing `Protocol`, or `name` unchanged when it
/// does not end with that suffix.
fn strip_protocol_suffix(name: &str) -> String {
    match name.strip_suffix("Protocol") {
        Some(base) => base.to_owned(),
        None => name.to_owned(),
    }
}
49
/// How a protocol method is dispatched, as decided by `classify_return_type`
/// from the method's declared return type.
enum MethodKind {
    /// The method has no return type, returns `()`, or returns `Result<(), _>`;
    /// the generated body forwards the message with `self.send(..)`.
    Send,
    /// The method returns `Response<T>`; the boxed type is that inner `T`,
    /// which becomes the `Result` type of the generated message.
    Request(Box<Type>),
}
54
/// Selects which runtime flavour `generate_blanket_impl` emits code for.
#[derive(Clone, Copy)]
enum RuntimeMode {
    /// Request methods wrap `self.request_raw(..)` in
    /// `Response::from_with_timeout` using `tasks::DEFAULT_REQUEST_TIMEOUT`;
    /// cfg marker traits get the `Tasks` suffix.
    Tasks,
    /// Request methods wrap `self.request(..)` in `Response::ready`;
    /// cfg marker traits get the `Threads` suffix.
    Threads,
}
60
61fn classify_return_type(ret: &ReturnType) -> Result<MethodKind, &Type> {
62 match ret {
63 ReturnType::Default => Ok(MethodKind::Send),
64 ReturnType::Type(_, ty) => {
65 if is_unit_type(ty) {
66 return Ok(MethodKind::Send);
67 }
68 if let Some(inner) = extract_response_inner(ty) {
69 return Ok(MethodKind::Request(inner));
70 }
71 if let Some(inner) = extract_result_inner(ty) {
72 if is_unit_type(&inner) {
73 return Ok(MethodKind::Send);
74 }
75 return Err(ty);
78 }
79 Err(ty)
80 }
81 }
82}
83
84fn extract_response_inner(ty: &Type) -> Option<Box<Type>> {
85 if let Type::Path(TypePath { path, .. }) = ty {
86 let seg = path.segments.last()?;
87 if seg.ident == "Response" {
88 if let PathArguments::AngleBracketed(args) = &seg.arguments {
89 if let Some(GenericArgument::Type(inner)) = args.args.first() {
90 return Some(Box::new(inner.clone()));
91 }
92 }
93 }
94 }
95 None
96}
97
98fn extract_result_inner(ty: &Type) -> Option<Box<Type>> {
99 if let Type::Path(TypePath { path, .. }) = ty {
100 let seg = path.segments.last()?;
101 if seg.ident == "Result" {
102 if let PathArguments::AngleBracketed(args) = &seg.arguments {
103 if let Some(GenericArgument::Type(inner)) = args.args.first() {
104 return Some(Box::new(inner.clone()));
105 }
106 }
107 }
108 }
109 None
110}
111
112fn is_unit_type(ty: &Type) -> bool {
113 if let Type::Tuple(tuple) = ty {
114 return tuple.elems.is_empty();
115 }
116 false
117}
118
/// Returns `true` when `name` is one of the primitive type names or one of
/// the prelude names listed below, i.e. a name that
/// `qualify_type_with_super` must not prefix with `super::`.
fn is_prelude_or_primitive(name: &str) -> bool {
    const PRIMITIVES: &[&str] = &[
        "bool", "char", "str", "i8", "i16", "i32", "i64", "i128", "isize", "u8", "u16", "u32",
        "u64", "u128", "usize", "f32", "f64",
    ];
    const PRELUDE: &[&str] = &[
        "Box", "String", "Vec", "Option", "Some", "None", "Result", "Ok", "Err", "ToString",
        "ToOwned",
    ];

    PRIMITIVES.contains(&name) || PRELUDE.contains(&name)
}
136
137fn qualify_type_with_super(ty: &Type) -> Type {
145 match ty {
146 Type::Path(TypePath { qself, path }) => {
147 if qself.is_some() || path.leading_colon.is_some() {
149 return ty.clone();
150 }
151 if let Some(first) = path.segments.first() {
152 let s = first.ident.to_string();
153 if matches!(
154 s.as_str(),
155 "crate" | "super" | "self" | "std" | "core" | "alloc"
156 ) {
157 return ty.clone();
158 }
159 }
160
161 let qualified_segments: syn::punctuated::Punctuated<_, _> = path
163 .segments
164 .iter()
165 .map(|seg| {
166 let mut new_seg = seg.clone();
167 if let PathArguments::AngleBracketed(ref mut args) = new_seg.arguments {
168 for arg in &mut args.args {
169 if let GenericArgument::Type(ref mut inner) = arg {
170 *inner = qualify_type_with_super(inner);
171 }
172 }
173 }
174 new_seg
175 })
176 .collect();
177
178 if let Some(first) = path.segments.first() {
180 if is_prelude_or_primitive(&first.ident.to_string()) {
181 return Type::Path(TypePath {
182 qself: None,
183 path: syn::Path {
184 leading_colon: None,
185 segments: qualified_segments,
186 },
187 });
188 }
189 }
190
191 let mut segments = syn::punctuated::Punctuated::new();
193 segments.push(syn::PathSegment {
194 ident: format_ident!("super"),
195 arguments: PathArguments::None,
196 });
197 for seg in qualified_segments {
198 segments.push(seg);
199 }
200
201 Type::Path(TypePath {
202 qself: None,
203 path: syn::Path {
204 leading_colon: None,
205 segments,
206 },
207 })
208 }
209 Type::Reference(r) => {
210 let mut new = r.clone();
211 new.elem = Box::new(qualify_type_with_super(&r.elem));
212 Type::Reference(new)
213 }
214 Type::Tuple(t) => {
215 let mut new = t.clone();
216 for elem in &mut new.elems {
217 *elem = qualify_type_with_super(elem);
218 }
219 Type::Tuple(new)
220 }
221 Type::Array(a) => {
222 let mut new = a.clone();
223 *new.elem = qualify_type_with_super(&a.elem);
224 Type::Array(new)
225 }
226 Type::Slice(s) => {
227 let mut new = s.clone();
228 *new.elem = qualify_type_with_super(&s.elem);
229 Type::Slice(new)
230 }
231 Type::Paren(p) => {
232 let mut new = p.clone();
233 *new.elem = qualify_type_with_super(&p.elem);
234 Type::Paren(new)
235 }
236 Type::TraitObject(t) => {
237 let mut new = t.clone();
238 for bound in &mut new.bounds {
239 if let syn::TypeParamBound::Trait(tb) = bound {
240 qualify_path_with_super(&mut tb.path);
241 }
242 }
243 Type::TraitObject(new)
244 }
245 Type::ImplTrait(t) => {
246 let mut new = t.clone();
247 for bound in &mut new.bounds {
248 if let syn::TypeParamBound::Trait(tb) = bound {
249 qualify_path_with_super(&mut tb.path);
250 }
251 }
252 Type::ImplTrait(new)
253 }
254 Type::BareFn(f) => {
255 let mut new = f.clone();
256 for arg in &mut new.inputs {
257 arg.ty = qualify_type_with_super(&arg.ty);
258 }
259 if let ReturnType::Type(_, ref mut ty) = new.output {
260 **ty = qualify_type_with_super(ty);
261 }
262 Type::BareFn(new)
263 }
264 _ => ty.clone(),
265 }
266}
267
268fn qualify_path_with_super(path: &mut syn::Path) {
271 for seg in &mut path.segments {
272 match &mut seg.arguments {
273 PathArguments::AngleBracketed(ref mut args) => {
274 for arg in &mut args.args {
275 if let GenericArgument::Type(ref mut inner) = arg {
276 *inner = qualify_type_with_super(inner);
277 }
278 }
279 }
280 PathArguments::Parenthesized(ref mut args) => {
281 for input in &mut args.inputs {
282 *input = qualify_type_with_super(input);
283 }
284 if let syn::ReturnType::Type(_, ref mut ty) = args.output {
285 **ty = qualify_type_with_super(ty);
286 }
287 }
288 PathArguments::None => {}
289 }
290 }
291}
292
/// The identifiers `protocol` derives from a trait and hands to
/// `generate_blanket_impl`.
struct ProtocolInfo<'a> {
    /// Name of the annotated protocol trait.
    trait_name: &'a Ident,
    /// Module holding the generated message structs (snake_case of the trait name).
    mod_name: &'a Ident,
    /// The `{Base}Ref` type alias for `Arc<dyn Trait>`, where `Base` is the
    /// trait name without a trailing `Protocol`.
    ref_name: &'a Ident,
    /// The generated `To{Base}Ref` conversion trait.
    converter_trait: &'a Ident,
    /// The `to_{base}_ref` method of the conversion trait.
    converter_method: &'a Ident,
}
300
301fn generate_blanket_impl(
308 info: &ProtocolInfo,
309 methods: &[ProtocolMethodInfo],
310 runtime_path: &proc_macro2::TokenStream,
311 mode: RuntimeMode,
312) -> proc_macro2::TokenStream {
313 let ProtocolInfo {
314 trait_name,
315 mod_name,
316 ref_name,
317 converter_trait,
318 converter_method,
319 } = info;
320
321 let handler_bounds: Vec<_> = methods
323 .iter()
324 .filter(|m| m.cfg_attrs.is_empty())
325 .map(|m| {
326 let sn = &m.struct_name;
327 quote! { #runtime_path::Handler<#mod_name::#sn> }
328 })
329 .collect();
330
331 let mut cfg_groups: Vec<(String, Vec<&ProtocolMethodInfo>)> = Vec::new();
334 for m in methods.iter().filter(|m| !m.cfg_attrs.is_empty()) {
335 let key: String = m
336 .cfg_attrs
337 .iter()
338 .map(|a| quote!(#a).to_string())
339 .collect::<Vec<_>>()
340 .join(",");
341 if let Some(group) = cfg_groups.iter_mut().find(|(k, _)| k == &key) {
342 group.1.push(m);
343 } else {
344 cfg_groups.push((key, vec![m]));
345 }
346 }
347
348 let mode_suffix = match mode {
349 RuntimeMode::Tasks => format_ident!("Tasks"),
350 RuntimeMode::Threads => format_ident!("Threads"),
351 };
352
353 let mut marker_trait_defs = Vec::new();
354 let mut marker_trait_bounds = Vec::new();
355
356 for (i, (_key, group_methods)) in cfg_groups.iter().enumerate() {
357 let marker_name = format_ident!("__{}Cfg{}{}", trait_name, i, mode_suffix);
358 let cfg_attrs = &group_methods[0].cfg_attrs;
359 let group_handler_bounds: Vec<_> = group_methods
360 .iter()
361 .map(|m| {
362 let sn = &m.struct_name;
363 quote! { #runtime_path::Handler<#mod_name::#sn> }
364 })
365 .collect();
366
367 let cfg_predicates: Vec<proc_macro2::TokenStream> = cfg_attrs
370 .iter()
371 .filter(|a| a.path().is_ident("cfg"))
372 .filter_map(|a| a.parse_args::<proc_macro2::TokenStream>().ok())
373 .collect();
374
375 let (positive_cfg, negated_cfg) = if cfg_predicates.len() == 1 {
376 let pred = &cfg_predicates[0];
377 (quote! { #[cfg(#pred)] }, quote! { #[cfg(not(#pred))] })
378 } else {
379 (
380 quote! { #[cfg(all(#(#cfg_predicates),*))] },
381 quote! { #[cfg(not(all(#(#cfg_predicates),*)))] },
382 )
383 };
384
385 marker_trait_defs.push(quote! {
386 #positive_cfg
387 #[doc(hidden)]
388 trait #marker_name: #(#group_handler_bounds)+* {}
389 #positive_cfg
390 impl<__T: #(#group_handler_bounds)+*> #marker_name for __T {}
391
392 #negated_cfg
393 #[doc(hidden)]
394 trait #marker_name {}
395 #negated_cfg
396 impl<__T> #marker_name for __T {}
397 });
398
399 marker_trait_bounds.push(quote! { + #marker_name });
400 }
401
402 let method_impls: Vec<_> = methods
405 .iter()
406 .map(|m| {
407 let method_name = &m.method_name;
408 let field_names = &m.field_names;
409 let params: Vec<_> = m.params.iter().collect();
410 let ret_ty = &m.ret_type;
411 let method_attrs = &m.method_attrs;
412
413 let struct_name = &m.struct_name;
414 let msg_construct = if field_names.is_empty() {
415 quote! { #mod_name::#struct_name }
416 } else {
417 quote! { #mod_name::#struct_name { #(#field_names),* } }
418 };
419
420 match &m.kind {
421 MethodKind::Send => {
422 let is_unit_return = match ret_ty {
423 ReturnType::Default => true,
424 ReturnType::Type(_, ty) => is_unit_type(ty),
425 };
426 let body = if is_unit_return {
427 quote! { let _ = self.send(#msg_construct); }
428 } else {
429 quote! { self.send(#msg_construct) }
430 };
431 quote! {
432 #(#method_attrs)*
433 fn #method_name(&self, #(#params),*) #ret_ty {
434 #body
435 }
436 }
437 }
438 MethodKind::Request(_) => {
439 let body = match mode {
440 RuntimeMode::Tasks => quote! {
441 spawned_concurrency::Response::from_with_timeout(
442 self.request_raw(#msg_construct),
443 spawned_concurrency::tasks::DEFAULT_REQUEST_TIMEOUT,
444 )
445 },
446 RuntimeMode::Threads => quote! {
447 spawned_concurrency::Response::ready(
448 self.request(#msg_construct),
449 )
450 },
451 };
452 quote! {
453 #(#method_attrs)*
454 fn #method_name(&self, #(#params),*) #ret_ty {
455 #body
456 }
457 }
458 }
459 }
460 })
461 .collect();
462
463 quote! {
464 #(#marker_trait_defs)*
465
466 impl<__A: #runtime_path::Actor #(+ #handler_bounds)* #(#marker_trait_bounds)*> #trait_name
467 for #runtime_path::ActorRef<__A>
468 {
469 #(#method_impls)*
470 }
471
472 impl<__A: #runtime_path::Actor #(+ #handler_bounds)* #(#marker_trait_bounds)*> #converter_trait
473 for #runtime_path::ActorRef<__A>
474 {
475 fn #converter_method(&self) -> #ref_name {
476 ::std::sync::Arc::new(self.clone())
477 }
478 }
479 }
480}
481
/// Everything `protocol` records about one trait method.
struct ProtocolMethodInfo {
    /// The method's name as written in the trait.
    method_name: Ident,
    /// Name of the generated message struct (PascalCase of the method name).
    struct_name: Ident,
    /// Parameter names after the receiver; they become the struct's fields.
    field_names: Vec<Ident>,
    /// Parameter types, in the same order as `field_names`.
    field_types: Vec<Type>,
    /// Whether the method is a send or a request.
    kind: MethodKind,
    /// The parameters after the receiver, exactly as declared.
    params: Vec<FnArg>,
    /// The method's declared return type.
    ret_type: ReturnType,
    /// The method's `doc`, `cfg` and `cfg_attr` attributes.
    method_attrs: Vec<Attribute>,
    /// The subset of the method's attributes that are `cfg` or `cfg_attr`.
    cfg_attrs: Vec<Attribute>,
}
494
495#[proc_macro_attribute]
561pub fn protocol(_attr: TokenStream, item: TokenStream) -> TokenStream {
562 let trait_def = parse_macro_input!(item as ItemTrait);
563 let trait_name = &trait_def.ident;
564 let trait_vis = &trait_def.vis;
565
566 if !trait_def.generics.params.is_empty() {
567 return syn::Error::new_spanned(
568 &trait_def.generics,
569 "generic type parameters on protocol traits are not supported",
570 )
571 .to_compile_error()
572 .into();
573 }
574
575 let base_name = strip_protocol_suffix(&trait_name.to_string());
576 let mod_name = format_ident!("{}", to_snake_case(&trait_name.to_string()));
577 let ref_name = format_ident!("{}Ref", base_name);
578 let converter_trait = format_ident!("To{}Ref", base_name);
579 let converter_method = format_ident!("to_{}_ref", to_snake_case(&base_name));
580
581 let mut methods: Vec<ProtocolMethodInfo> = Vec::new();
582
583 for item in &trait_def.items {
584 if !matches!(item, TraitItem::Fn(_)) {
585 return syn::Error::new_spanned(
586 item,
587 "protocol traits may only contain methods; \
588 associated types, constants, and other items are not supported",
589 )
590 .to_compile_error()
591 .into();
592 }
593 if let TraitItem::Fn(method) = item {
594 if method.sig.asyncness.is_some() {
595 return syn::Error::new_spanned(
596 &method.sig,
597 "protocol methods must not be async; \
598 use Response<T> as the return type for requests",
599 )
600 .to_compile_error()
601 .into();
602 }
603
604 match method.sig.inputs.first() {
606 Some(FnArg::Receiver(r)) if r.reference.is_some() && r.mutability.is_none() => {}
607 _ => {
608 return syn::Error::new_spanned(
609 &method.sig,
610 "protocol methods must take `&self` as the first parameter",
611 )
612 .to_compile_error()
613 .into();
614 }
615 }
616
617 let method_name = method.sig.ident.clone();
618 let struct_name = format_ident!("{}", to_pascal_case(&method_name.to_string()));
619
620 let mut field_names: Vec<Ident> = Vec::new();
621 let mut field_types: Vec<Type> = Vec::new();
622 let mut params: Vec<FnArg> = Vec::new();
623
624 for arg in method.sig.inputs.iter().skip(1) {
625 if let FnArg::Typed(pat_type) = arg {
626 if let Pat::Ident(pat_ident) = &*pat_type.pat {
627 field_names.push(pat_ident.ident.clone());
628 field_types.push((*pat_type.ty).clone());
629 } else {
630 return syn::Error::new_spanned(
631 &pat_type.pat,
632 "protocol methods only support simple identifier patterns \
633 (e.g., `name: Type`)",
634 )
635 .to_compile_error()
636 .into();
637 }
638 }
639 params.push(arg.clone());
640 }
641
642 let kind = match classify_return_type(&method.sig.output) {
643 Ok(kind) => kind,
644 Err(ty) => {
645 return syn::Error::new_spanned(
646 ty,
647 "unsupported return type in protocol method; \
648 use Response<T> for requests (works in both async and sync modes), \
649 Result<(), ActorError> for sends, or no return type for fire-and-forget",
650 )
651 .to_compile_error()
652 .into();
653 }
654 };
655
656 let method_attrs: Vec<Attribute> = method
657 .attrs
658 .iter()
659 .filter(|a| {
660 a.path().is_ident("doc")
661 || a.path().is_ident("cfg")
662 || a.path().is_ident("cfg_attr")
663 })
664 .cloned()
665 .collect();
666 let cfg_attrs: Vec<Attribute> = method
667 .attrs
668 .iter()
669 .filter(|a| a.path().is_ident("cfg") || a.path().is_ident("cfg_attr"))
670 .cloned()
671 .collect();
672
673 methods.push(ProtocolMethodInfo {
674 method_name,
675 struct_name,
676 field_names,
677 field_types,
678 kind,
679 params,
680 ret_type: method.sig.output.clone(),
681 method_attrs,
682 cfg_attrs,
683 });
684 }
685 }
686
687 let msg_structs: Vec<_> = methods
692 .iter()
693 .map(|m| {
694 let struct_name = &m.struct_name;
695 let field_names = &m.field_names;
696 let qualified_field_types: Vec<Type> =
697 m.field_types.iter().map(qualify_type_with_super).collect();
698 let method_attrs = &m.method_attrs;
699 let cfg_attrs = &m.cfg_attrs;
700 let msg_result_ty: Type = match &m.kind {
701 MethodKind::Send => syn::parse_quote! { () },
702 MethodKind::Request(inner) => qualify_type_with_super(inner),
703 };
704
705 if field_names.is_empty() {
706 quote! {
707 #(#method_attrs)*
708 #[derive(Clone)]
709 pub struct #struct_name;
710 #(#cfg_attrs)*
711 impl Message for #struct_name {
712 type Result = #msg_result_ty;
713 }
714 }
715 } else {
716 quote! {
717 #(#method_attrs)*
718 pub struct #struct_name {
719 #(pub #field_names: #qualified_field_types,)*
720 }
721 #(#cfg_attrs)*
722 impl Message for #struct_name {
723 type Result = #msg_result_ty;
724 }
725 }
726 }
727 })
728 .collect();
729
730 let tasks = quote! { spawned_concurrency::tasks };
732 let threads = quote! { spawned_concurrency::threads };
733 let proto_info = ProtocolInfo {
734 trait_name,
735 mod_name: &mod_name,
736 ref_name: &ref_name,
737 converter_trait: &converter_trait,
738 converter_method: &converter_method,
739 };
740 let tasks_impl = generate_blanket_impl(&proto_info, &methods, &tasks, RuntimeMode::Tasks);
741 let threads_impl = generate_blanket_impl(&proto_info, &methods, &threads, RuntimeMode::Threads);
742 let blanket_impls = quote! { #tasks_impl #threads_impl };
743
744 let ref_doc = format!(
745 "Type-erased reference to any actor implementing [`{trait_name}`].\n\n\
746 Use this type to store protocol references without depending on the concrete actor type."
747 );
748
749 let output = quote! {
752 #[allow(dead_code)]
753 #trait_def
754
755 #[doc = #ref_doc]
756 #trait_vis type #ref_name = ::std::sync::Arc<dyn #trait_name>;
757
758 #trait_vis mod #mod_name {
759 use super::*;
760 use spawned_concurrency::message::Message;
761 #(#msg_structs)*
762 }
763
764 #trait_vis trait #converter_trait {
765 fn #converter_method(&self) -> #ref_name;
766 }
767
768 impl #converter_trait for #ref_name {
769 fn #converter_method(&self) -> #ref_name {
770 ::std::sync::Arc::clone(self)
771 }
772 }
773
774 #blanket_impls
775 };
776
777 output.into()
778}
779
780#[proc_macro_attribute]
856pub fn actor(attr: TokenStream, item: TokenStream) -> TokenStream {
857 let mut impl_block = parse_macro_input!(item as ItemImpl);
858
859 let self_ty = &impl_block.self_ty;
860 let (impl_generics, _, where_clause) = impl_block.generics.split_for_impl();
861
862 let bridge_traits: Vec<Ident> = if attr.is_empty() {
864 Vec::new()
865 } else {
866 let parser = |input: syn::parse::ParseStream| -> syn::Result<Vec<Ident>> {
867 let mut protocols = Vec::new();
868 while !input.is_empty() {
869 let key: Ident = input.parse()?;
870 if key != "protocol" {
871 return Err(syn::Error::new(
872 key.span(),
873 "unknown parameter, expected `protocol`",
874 ));
875 }
876 if input.peek(syn::Token![=]) {
877 let _: syn::Token![=] = input.parse()?;
879 protocols.push(input.parse()?);
880 } else {
881 let content;
883 syn::parenthesized!(content in input);
884 let punctuated = content.parse_terminated(Ident::parse, syn::Token![,])?;
885 protocols.extend(punctuated);
886 }
887 if input.peek(syn::Token![,]) {
888 let _: syn::Token![,] = input.parse()?;
889 }
890 }
891 Ok(protocols)
892 };
893 match syn::parse::Parser::parse(parser, attr) {
894 Ok(traits) => traits,
895 Err(e) => return e.to_compile_error().into(),
896 }
897 };
898
899 let mut started_method: Option<ImplItemFn> = None;
901 let mut stopped_method: Option<ImplItemFn> = None;
902 let mut has_async = false;
903
904 let mut items_to_keep = Vec::new();
905 for item in impl_block.items.drain(..) {
906 if let ImplItem::Fn(ref method) = item {
907 let is_started = method.attrs.iter().any(|a| a.path().is_ident("started"));
908 let is_stopped = method.attrs.iter().any(|a| a.path().is_ident("stopped"));
909
910 if is_started {
911 if started_method.is_some() {
912 return syn::Error::new_spanned(
913 &method.sig,
914 "only one #[started] method is allowed per actor",
915 )
916 .to_compile_error()
917 .into();
918 }
919 if method.attrs.iter().any(|a| {
920 a.path().is_ident("handler")
921 || a.path().is_ident("send_handler")
922 || a.path().is_ident("request_handler")
923 }) {
924 return syn::Error::new_spanned(
925 &method.sig,
926 "#[started] cannot be combined with handler attributes",
927 )
928 .to_compile_error()
929 .into();
930 }
931 if method.sig.inputs.len() != 2 {
933 return syn::Error::new_spanned(
934 &method.sig,
935 "#[started] method must take exactly (&mut self, &Context<Self>)",
936 )
937 .to_compile_error()
938 .into();
939 }
940 if !matches!(method.sig.inputs.first(), Some(FnArg::Receiver(r)) if r.mutability.is_some())
941 {
942 return syn::Error::new_spanned(
943 &method.sig,
944 "#[started] method's first parameter must be `&mut self`",
945 )
946 .to_compile_error()
947 .into();
948 }
949 let mut m = method.clone();
950 m.attrs.retain(|a| !a.path().is_ident("started"));
951 m.vis = syn::Visibility::Inherited;
952 m.sig.ident = format_ident!("started");
953 if m.sig.asyncness.is_some() {
954 has_async = true;
955 }
956 started_method = Some(m);
957 continue;
958 }
959
960 if is_stopped {
961 if stopped_method.is_some() {
962 return syn::Error::new_spanned(
963 &method.sig,
964 "only one #[stopped] method is allowed per actor",
965 )
966 .to_compile_error()
967 .into();
968 }
969 if method.attrs.iter().any(|a| {
970 a.path().is_ident("handler")
971 || a.path().is_ident("send_handler")
972 || a.path().is_ident("request_handler")
973 }) {
974 return syn::Error::new_spanned(
975 &method.sig,
976 "#[stopped] cannot be combined with handler attributes",
977 )
978 .to_compile_error()
979 .into();
980 }
981 if method.sig.inputs.len() != 2 {
983 return syn::Error::new_spanned(
984 &method.sig,
985 "#[stopped] method must take exactly (&mut self, &Context<Self>)",
986 )
987 .to_compile_error()
988 .into();
989 }
990 if !matches!(method.sig.inputs.first(), Some(FnArg::Receiver(r)) if r.mutability.is_some())
991 {
992 return syn::Error::new_spanned(
993 &method.sig,
994 "#[stopped] method's first parameter must be `&mut self`",
995 )
996 .to_compile_error()
997 .into();
998 }
999 let mut m = method.clone();
1000 m.attrs.retain(|a| !a.path().is_ident("stopped"));
1001 m.vis = syn::Visibility::Inherited;
1002 m.sig.ident = format_ident!("stopped");
1003 if m.sig.asyncness.is_some() {
1004 has_async = true;
1005 }
1006 stopped_method = Some(m);
1007 continue;
1008 }
1009 }
1010 items_to_keep.push(item);
1011 }
1012 impl_block.items = items_to_keep;
1013
1014 let mut handler_impls = Vec::new();
1016
1017 for item in &mut impl_block.items {
1018 if let ImplItem::Fn(method) = item {
1019 let handler_idx = method.attrs.iter().position(|attr| {
1020 attr.path().is_ident("handler")
1021 || attr.path().is_ident("send_handler")
1022 || attr.path().is_ident("request_handler")
1023 });
1024
1025 if let Some(idx) = handler_idx {
1026 method.attrs.remove(idx);
1027
1028 let extra_attrs: Vec<_> = method
1031 .attrs
1032 .iter()
1033 .filter(|a| {
1034 !a.path().is_ident("handler")
1035 && !a.path().is_ident("send_handler")
1036 && !a.path().is_ident("request_handler")
1037 && !a.path().is_ident("started")
1038 && !a.path().is_ident("stopped")
1039 })
1040 .cloned()
1041 .collect();
1042
1043 let method_name = &method.sig.ident;
1044 if method.sig.asyncness.is_some() {
1045 has_async = true;
1046 }
1047
1048 let param_count = method.sig.inputs.len();
1050 if param_count != 3 {
1051 return syn::Error::new_spanned(
1052 &method.sig,
1053 format!(
1054 "handler method must have 3 parameters (&mut self, msg: M, ctx: &Context<Self>), found {param_count}"
1055 ),
1056 )
1057 .to_compile_error()
1058 .into();
1059 }
1060
1061 let msg_ty = match method.sig.inputs.iter().nth(1) {
1063 Some(FnArg::Typed(pat_type)) => &*pat_type.ty,
1064 _ => {
1065 return syn::Error::new_spanned(
1066 &method.sig,
1067 "handler method must have signature: fn(&mut self, msg: M, ctx: &Context<Self>) -> R",
1068 )
1069 .to_compile_error()
1070 .into();
1071 }
1072 };
1073
1074 let ret_ty: Box<Type> = match &method.sig.output {
1076 ReturnType::Default => syn::parse_quote! { () },
1077 ReturnType::Type(_, ty) => ty.clone(),
1078 };
1079
1080 let handler_impl = if method.sig.asyncness.is_some() {
1081 quote! {
1082 #(#extra_attrs)*
1083 impl #impl_generics Handler<#msg_ty> for #self_ty #where_clause {
1084 async fn handle(&mut self, msg: #msg_ty, ctx: &Context<Self>) -> #ret_ty {
1085 self.#method_name(msg, ctx).await
1086 }
1087 }
1088 }
1089 } else {
1090 quote! {
1091 #(#extra_attrs)*
1092 impl #impl_generics Handler<#msg_ty> for #self_ty #where_clause {
1093 fn handle(&mut self, msg: #msg_ty, ctx: &Context<Self>) -> #ret_ty {
1094 self.#method_name(msg, ctx)
1095 }
1096 }
1097 }
1098 };
1099
1100 handler_impls.push(handler_impl);
1101 }
1102 }
1103 }
1104
1105 let lifecycle_methods: Vec<&ImplItemFn> = [started_method.as_ref(), stopped_method.as_ref()]
1107 .into_iter()
1108 .flatten()
1109 .collect();
1110
1111 let protocol_doc = if bridge_traits.is_empty() {
1112 quote! {}
1113 } else {
1114 let lines: Vec<String> = bridge_traits.iter().map(|t| format!("- [`{t}`]")).collect();
1115 let doc_body = format!(
1116 "# Protocol\n\n\
1117 When started, `ActorRef<{ty}>` implements:\n\n\
1118 {lines}\n\n\
1119 See the protocol trait docs for the full API.",
1120 ty = quote!(#self_ty),
1121 lines = lines.join("\n"),
1122 );
1123 quote! { #[doc = #doc_body] }
1124 };
1125
1126 let actor_impl = quote! {
1127 #protocol_doc
1128 impl #impl_generics Actor for #self_ty #where_clause {
1129 #(#lifecycle_methods)*
1130 }
1131 };
1132
1133 let runtime_path = if has_async {
1135 quote! { spawned_concurrency::tasks }
1136 } else {
1137 quote! { spawned_concurrency::threads }
1138 };
1139
1140 let bridge_asserts: Vec<_> = bridge_traits
1141 .iter()
1142 .map(|trait_name| {
1143 quote! {
1144 const _: () = {
1145 fn _assert_bridge<__T: #trait_name>() {}
1146 fn _check() {
1147 _assert_bridge::<#runtime_path::ActorRef<#self_ty>>();
1148 }
1149 };
1150 }
1151 })
1152 .collect();
1153
1154 let output = quote! {
1155 #actor_impl
1156 #impl_block
1157 #(#handler_impls)*
1158 #(#bridge_asserts)*
1159 };
1160
1161 output.into()
1162}