Skip to main content

spawned_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{
4    parse::Parse, parse_macro_input, Attribute, FnArg, GenericArgument, Ident, ImplItem,
5    ImplItemFn, ItemImpl, ItemTrait, Pat, PathArguments, ReturnType, TraitItem, Type, TypePath,
6};
7
8// --- Helpers for #[protocol] ---
9
/// Converts a `PascalCase` / `camelCase` identifier into `snake_case`.
///
/// An underscore is inserted before an uppercase character when it is not the
/// first character and either the previous character is lowercase or the next
/// character is lowercase. The second rule keeps acronyms together:
/// `HTTPServer` → `http_server`, `ChatRoomProtocol` → `chat_room_protocol`.
///
/// No extra underscore is inserted right after an existing one
/// (`Foo_Bar` → `foo_bar`, not `foo__bar`).
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, &ch) in chars.iter().enumerate() {
        if !ch.is_uppercase() {
            result.push(ch);
            continue;
        }
        // Underscore before an uppercase char if it is not at the start, the
        // output does not already end in `_`, and the previous char is
        // lowercase OR the next char is lowercase (handles acronyms).
        if i > 0 && !result.ends_with('_') {
            let prev_lower = chars[i - 1].is_lowercase();
            let next_lower = chars.get(i + 1).is_some_and(|c| c.is_lowercase());
            if prev_lower || next_lower {
                result.push('_');
            }
        }
        // `is_uppercase` is Unicode-aware, so lowering must be too:
        // `to_ascii_lowercase` would leave non-ASCII capitals (e.g. `Ä`) untouched.
        result.extend(ch.to_lowercase());
    }
    result
}
32
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// Empty segments (from leading, trailing or doubled underscores) are dropped.
/// Only the first character of each segment is uppercased; the rest of the
/// segment is kept verbatim (`send_message` → `SendMessage`, `get__id` → `GetId`).
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            out.extend(head.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
45
/// Returns `name` without one trailing `Protocol`, or `name` unchanged when it
/// has no such suffix (`ChatRoomProtocol` → `ChatRoom`, `Counter` → `Counter`).
fn strip_protocol_suffix(name: &str) -> String {
    const SUFFIX: &str = "Protocol";
    match name.strip_suffix(SUFFIX) {
        Some(base) => base.to_owned(),
        None => name.to_owned(),
    }
}
49
50enum MethodKind {
51    Send,
52    Request(Box<Type>),
53}
54
/// Which `spawned_concurrency` runtime a generated blanket impl targets.
#[derive(Copy, Clone)]
enum RuntimeMode {
    /// `spawned_concurrency::tasks` (async handlers).
    Tasks,
    /// `spawned_concurrency::threads` (sync handlers).
    Threads,
}
60
61fn classify_return_type(ret: &ReturnType) -> Result<MethodKind, &Type> {
62    match ret {
63        ReturnType::Default => Ok(MethodKind::Send),
64        ReturnType::Type(_, ty) => {
65            if is_unit_type(ty) {
66                return Ok(MethodKind::Send);
67            }
68            if let Some(inner) = extract_response_inner(ty) {
69                return Ok(MethodKind::Request(inner));
70            }
71            if let Some(inner) = extract_result_inner(ty) {
72                if is_unit_type(&inner) {
73                    return Ok(MethodKind::Send);
74                }
75                // Result<T, ActorError> where T ≠ () is no longer supported;
76                // use Response<T> which works in both modes.
77                return Err(ty);
78            }
79            Err(ty)
80        }
81    }
82}
83
84fn extract_response_inner(ty: &Type) -> Option<Box<Type>> {
85    if let Type::Path(TypePath { path, .. }) = ty {
86        let seg = path.segments.last()?;
87        if seg.ident == "Response" {
88            if let PathArguments::AngleBracketed(args) = &seg.arguments {
89                if let Some(GenericArgument::Type(inner)) = args.args.first() {
90                    return Some(Box::new(inner.clone()));
91                }
92            }
93        }
94    }
95    None
96}
97
98fn extract_result_inner(ty: &Type) -> Option<Box<Type>> {
99    if let Type::Path(TypePath { path, .. }) = ty {
100        let seg = path.segments.last()?;
101        if seg.ident == "Result" {
102            if let PathArguments::AngleBracketed(args) = &seg.arguments {
103                if let Some(GenericArgument::Type(inner)) = args.args.first() {
104                    return Some(Box::new(inner.clone()));
105                }
106            }
107        }
108    }
109    None
110}
111
112fn is_unit_type(ty: &Type) -> bool {
113    if let Type::Tuple(tuple) = ty {
114        return tuple.elems.is_empty();
115    }
116    false
117}
118
/// Returns true for names usable without any import: primitive types, plus
/// the Rust 2021 prelude names listed in `PRELUDE`.
///
/// Such names are not items of the parent module, so they cannot be reached
/// as `super::Name` and must be left unqualified.
fn is_prelude_or_primitive(name: &str) -> bool {
    const PRIMITIVES: &[&str] = &[
        "bool", "char", "str", "i8", "i16", "i32", "i64", "i128", "isize", "u8", "u16", "u32",
        "u64", "u128", "usize", "f32", "f64",
    ];
    const PRELUDE: &[&str] = &[
        "Box", "String", "Vec", "Option", "Some", "None", "Result", "Ok", "Err", "ToString",
        "ToOwned",
    ];
    PRIMITIVES.contains(&name) || PRELUDE.contains(&name)
}
136
137/// Qualify a type with `super::` so it resolves to the parent module's scope.
138/// This prevents name collisions when a generated struct name shadows an imported type
139/// via `use super::*`.
140///
141/// Prelude/primitive types are left unqualified (they can't be accessed via `super::`).
142/// User-defined types get `super::` prepended: `Event` → `super::Event`.
143/// Generic args are recursively qualified: `Vec<Event>` → `Vec<super::Event>`.
144fn qualify_type_with_super(ty: &Type) -> Type {
145    match ty {
146        Type::Path(TypePath { qself, path }) => {
147            // Leave qualified paths as-is: <X as T>::Y, ::abs::path, crate::, super::, self::
148            if qself.is_some() || path.leading_colon.is_some() {
149                return ty.clone();
150            }
151            if let Some(first) = path.segments.first() {
152                let s = first.ident.to_string();
153                if matches!(
154                    s.as_str(),
155                    "crate" | "super" | "self" | "std" | "core" | "alloc"
156                ) {
157                    return ty.clone();
158                }
159            }
160
161            // Recursively qualify generic arguments in all segments
162            let qualified_segments: syn::punctuated::Punctuated<_, _> = path
163                .segments
164                .iter()
165                .map(|seg| {
166                    let mut new_seg = seg.clone();
167                    if let PathArguments::AngleBracketed(ref mut args) = new_seg.arguments {
168                        for arg in &mut args.args {
169                            if let GenericArgument::Type(ref mut inner) = arg {
170                                *inner = qualify_type_with_super(inner);
171                            }
172                        }
173                    }
174                    new_seg
175                })
176                .collect();
177
178            // Only prepend super:: for non-prelude types
179            if let Some(first) = path.segments.first() {
180                if is_prelude_or_primitive(&first.ident.to_string()) {
181                    return Type::Path(TypePath {
182                        qself: None,
183                        path: syn::Path {
184                            leading_colon: None,
185                            segments: qualified_segments,
186                        },
187                    });
188                }
189            }
190
191            // Prepend super:: for user-defined types
192            let mut segments = syn::punctuated::Punctuated::new();
193            segments.push(syn::PathSegment {
194                ident: format_ident!("super"),
195                arguments: PathArguments::None,
196            });
197            for seg in qualified_segments {
198                segments.push(seg);
199            }
200
201            Type::Path(TypePath {
202                qself: None,
203                path: syn::Path {
204                    leading_colon: None,
205                    segments,
206                },
207            })
208        }
209        Type::Reference(r) => {
210            let mut new = r.clone();
211            new.elem = Box::new(qualify_type_with_super(&r.elem));
212            Type::Reference(new)
213        }
214        Type::Tuple(t) => {
215            let mut new = t.clone();
216            for elem in &mut new.elems {
217                *elem = qualify_type_with_super(elem);
218            }
219            Type::Tuple(new)
220        }
221        Type::Array(a) => {
222            let mut new = a.clone();
223            *new.elem = qualify_type_with_super(&a.elem);
224            Type::Array(new)
225        }
226        Type::Slice(s) => {
227            let mut new = s.clone();
228            *new.elem = qualify_type_with_super(&s.elem);
229            Type::Slice(new)
230        }
231        Type::Paren(p) => {
232            let mut new = p.clone();
233            *new.elem = qualify_type_with_super(&p.elem);
234            Type::Paren(new)
235        }
236        Type::TraitObject(t) => {
237            let mut new = t.clone();
238            for bound in &mut new.bounds {
239                if let syn::TypeParamBound::Trait(tb) = bound {
240                    qualify_path_with_super(&mut tb.path);
241                }
242            }
243            Type::TraitObject(new)
244        }
245        Type::ImplTrait(t) => {
246            let mut new = t.clone();
247            for bound in &mut new.bounds {
248                if let syn::TypeParamBound::Trait(tb) = bound {
249                    qualify_path_with_super(&mut tb.path);
250                }
251            }
252            Type::ImplTrait(new)
253        }
254        Type::BareFn(f) => {
255            let mut new = f.clone();
256            for arg in &mut new.inputs {
257                arg.ty = qualify_type_with_super(&arg.ty);
258            }
259            if let ReturnType::Type(_, ref mut ty) = new.output {
260                **ty = qualify_type_with_super(ty);
261            }
262            Type::BareFn(new)
263        }
264        _ => ty.clone(),
265    }
266}
267
268/// Qualify a path's generic arguments with `super::`, used for trait bounds
269/// in `dyn Trait` and `impl Trait` types.
270fn qualify_path_with_super(path: &mut syn::Path) {
271    for seg in &mut path.segments {
272        match &mut seg.arguments {
273            PathArguments::AngleBracketed(ref mut args) => {
274                for arg in &mut args.args {
275                    if let GenericArgument::Type(ref mut inner) = arg {
276                        *inner = qualify_type_with_super(inner);
277                    }
278                }
279            }
280            PathArguments::Parenthesized(ref mut args) => {
281                for input in &mut args.inputs {
282                    *input = qualify_type_with_super(input);
283                }
284                if let syn::ReturnType::Type(_, ref mut ty) = args.output {
285                    **ty = qualify_type_with_super(ty);
286                }
287            }
288            PathArguments::None => {}
289        }
290    }
291}
292
293struct ProtocolInfo<'a> {
294    trait_name: &'a Ident,
295    mod_name: &'a Ident,
296    ref_name: &'a Ident,
297    converter_trait: &'a Ident,
298    converter_method: &'a Ident,
299}
300
301/// Generates a blanket `impl Protocol for ActorRef<A>` and `impl ToXRef for ActorRef<A>`
302/// for a given runtime path (tasks or threads).
303///
304/// For cfg-gated methods, we generate marker traits that conditionally require
305/// the Handler bounds. This avoids putting extra where-clause bounds on individual
306/// methods (which Rust rejects as "impl has stricter requirements than trait").
307fn generate_blanket_impl(
308    info: &ProtocolInfo,
309    methods: &[ProtocolMethodInfo],
310    runtime_path: &proc_macro2::TokenStream,
311    mode: RuntimeMode,
312) -> proc_macro2::TokenStream {
313    let ProtocolInfo {
314        trait_name,
315        mod_name,
316        ref_name,
317        converter_trait,
318        converter_method,
319    } = info;
320
321    // Unconditional methods: Handler bounds go directly on the impl block.
322    let handler_bounds: Vec<_> = methods
323        .iter()
324        .filter(|m| m.cfg_attrs.is_empty())
325        .map(|m| {
326            let sn = &m.struct_name;
327            quote! { #runtime_path::Handler<#mod_name::#sn> }
328        })
329        .collect();
330
331    // Group cfg-gated methods by their cfg predicate.
332    // Each unique group gets a marker trait that conditionally requires the Handler bounds.
333    let mut cfg_groups: Vec<(String, Vec<&ProtocolMethodInfo>)> = Vec::new();
334    for m in methods.iter().filter(|m| !m.cfg_attrs.is_empty()) {
335        let key: String = m
336            .cfg_attrs
337            .iter()
338            .map(|a| quote!(#a).to_string())
339            .collect::<Vec<_>>()
340            .join(",");
341        if let Some(group) = cfg_groups.iter_mut().find(|(k, _)| k == &key) {
342            group.1.push(m);
343        } else {
344            cfg_groups.push((key, vec![m]));
345        }
346    }
347
348    let mode_suffix = match mode {
349        RuntimeMode::Tasks => format_ident!("Tasks"),
350        RuntimeMode::Threads => format_ident!("Threads"),
351    };
352
353    let mut marker_trait_defs = Vec::new();
354    let mut marker_trait_bounds = Vec::new();
355
356    for (i, (_key, group_methods)) in cfg_groups.iter().enumerate() {
357        let marker_name = format_ident!("__{}Cfg{}{}", trait_name, i, mode_suffix);
358        let cfg_attrs = &group_methods[0].cfg_attrs;
359        let group_handler_bounds: Vec<_> = group_methods
360            .iter()
361            .map(|m| {
362                let sn = &m.struct_name;
363                quote! { #runtime_path::Handler<#mod_name::#sn> }
364            })
365            .collect();
366
367        // Combine cfg predicates and negate for the fallback impl.
368        // #[cfg(A)] #[cfg(B)] -> active: all(A, B), inactive: not(all(A, B))
369        let cfg_predicates: Vec<proc_macro2::TokenStream> = cfg_attrs
370            .iter()
371            .filter(|a| a.path().is_ident("cfg"))
372            .filter_map(|a| a.parse_args::<proc_macro2::TokenStream>().ok())
373            .collect();
374
375        let (positive_cfg, negated_cfg) = if cfg_predicates.len() == 1 {
376            let pred = &cfg_predicates[0];
377            (quote! { #[cfg(#pred)] }, quote! { #[cfg(not(#pred))] })
378        } else {
379            (
380                quote! { #[cfg(all(#(#cfg_predicates),*))] },
381                quote! { #[cfg(not(all(#(#cfg_predicates),*)))] },
382            )
383        };
384
385        marker_trait_defs.push(quote! {
386            #positive_cfg
387            #[doc(hidden)]
388            trait #marker_name: #(#group_handler_bounds)+* {}
389            #positive_cfg
390            impl<__T: #(#group_handler_bounds)+*> #marker_name for __T {}
391
392            #negated_cfg
393            #[doc(hidden)]
394            trait #marker_name {}
395            #negated_cfg
396            impl<__T> #marker_name for __T {}
397        });
398
399        marker_trait_bounds.push(quote! { + #marker_name });
400    }
401
402    // Generate method implementations (no where clauses needed — bounds come from
403    // marker traits or the impl block).
404    let method_impls: Vec<_> = methods
405        .iter()
406        .map(|m| {
407            let method_name = &m.method_name;
408            let field_names = &m.field_names;
409            let params: Vec<_> = m.params.iter().collect();
410            let ret_ty = &m.ret_type;
411            let method_attrs = &m.method_attrs;
412
413            let struct_name = &m.struct_name;
414            let msg_construct = if field_names.is_empty() {
415                quote! { #mod_name::#struct_name }
416            } else {
417                quote! { #mod_name::#struct_name { #(#field_names),* } }
418            };
419
420            match &m.kind {
421                MethodKind::Send => {
422                    let is_unit_return = match ret_ty {
423                        ReturnType::Default => true,
424                        ReturnType::Type(_, ty) => is_unit_type(ty),
425                    };
426                    let body = if is_unit_return {
427                        quote! { let _ = self.send(#msg_construct); }
428                    } else {
429                        quote! { self.send(#msg_construct) }
430                    };
431                    quote! {
432                        #(#method_attrs)*
433                        fn #method_name(&self, #(#params),*) #ret_ty {
434                            #body
435                        }
436                    }
437                }
438                MethodKind::Request(_) => {
439                    let body = match mode {
440                        RuntimeMode::Tasks => quote! {
441                            spawned_concurrency::Response::from_with_timeout(
442                                self.request_raw(#msg_construct),
443                                spawned_concurrency::tasks::DEFAULT_REQUEST_TIMEOUT,
444                            )
445                        },
446                        RuntimeMode::Threads => quote! {
447                            spawned_concurrency::Response::ready(
448                                self.request(#msg_construct),
449                            )
450                        },
451                    };
452                    quote! {
453                        #(#method_attrs)*
454                        fn #method_name(&self, #(#params),*) #ret_ty {
455                            #body
456                        }
457                    }
458                }
459            }
460        })
461        .collect();
462
463    quote! {
464        #(#marker_trait_defs)*
465
466        impl<__A: #runtime_path::Actor #(+ #handler_bounds)* #(#marker_trait_bounds)*> #trait_name
467            for #runtime_path::ActorRef<__A>
468        {
469            #(#method_impls)*
470        }
471
472        impl<__A: #runtime_path::Actor #(+ #handler_bounds)* #(#marker_trait_bounds)*> #converter_trait
473            for #runtime_path::ActorRef<__A>
474        {
475            fn #converter_method(&self) -> #ref_name {
476                ::std::sync::Arc::new(self.clone())
477            }
478        }
479    }
480}
481
482struct ProtocolMethodInfo {
483    method_name: Ident,
484    struct_name: Ident,
485    field_names: Vec<Ident>,
486    field_types: Vec<Type>,
487    kind: MethodKind,
488    params: Vec<FnArg>,
489    ret_type: ReturnType,
490    /// Doc + cfg attributes to propagate to blanket impl methods.
491    method_attrs: Vec<Attribute>,
492    cfg_attrs: Vec<Attribute>,
493}
494
495/// Generates message types and blanket implementations from a protocol trait.
496///
497/// `#[protocol]` transforms a trait definition into a full message-passing interface.
498/// Each trait method becomes a message struct, and any `ActorRef<A>` where `A` handles
499/// those messages automatically implements the trait — so you call methods directly on
500/// the actor reference.
501///
502/// # What It Generates
503///
504/// For a trait `FooProtocol`, the macro generates:
505///
506/// 1. **Message structs** in a `foo_protocol` submodule — one per method, with public
507///    fields matching the method parameters. Each implements `Message`.
508///    Unit structs (methods with no parameters beyond `&self`) automatically derive `Clone`.
509/// 2. **Type alias** `pub type FooRef = Arc<dyn FooProtocol>` — a type-erased reference.
510/// 3. **Converter trait** `ToFooRef` with a `to_foo_ref(&self) -> FooRef` method.
511/// 4. **Blanket impls** — `impl FooProtocol for ActorRef<A>` and `impl ToFooRef for ActorRef<A>`
512///    for any actor `A` that handles all the generated message types.
513///
514/// # Type Resolution in Generated Structs
515///
516/// The generated message module uses `use super::*` to access types from the
517/// parent scope. Types used in method parameters are qualified with `super::`
518/// in the generated structs — this means any type you reference in a protocol
519/// method signature must be in scope where the `#[protocol]` trait is defined.
520///
521/// Prelude types (`String`, `Vec`, `Option`, `Result`, `Box`, `bool`, `u32`, etc.)
522/// and fully qualified paths (`std::`, `core::`, `alloc::`) are used as-is without
523/// `super::` prefixing.
524///
525/// # Return Type Conventions
526///
527/// The return type on each method determines the message kind:
528///
529/// | Return type | Kind | Runtime | Caller behavior |
530/// |-------------|------|---------|-----------------|
531/// | `Response<T>` | Request | Both | `.await.unwrap()` (tasks) / `.unwrap()` (threads) — 5s default timeout |
532/// | `Result<(), ActorError>` | Send | Both | Returns send result |
533/// | *(none)* / `-> ()` | Send | Both | Fire-and-forget (discards send result) |
534///
535/// # Naming
536///
537/// - **Module**: trait name → `snake_case` (e.g., `ChatRoomProtocol` → `chat_room_protocol`)
538/// - **Structs**: method name → `PascalCase` (e.g., `send_message` → `SendMessage`)
539/// - **XRef alias**: strips `Protocol` suffix → `{Base}Ref` (e.g., `ChatRoomProtocol` → `ChatRoomRef`)
540/// - **Converter**: `To{Base}Ref` with `to_{snake_base}_ref()` method
541///
542/// # Example
543///
544/// ```ignore
545/// use spawned_concurrency::Response;
546/// use spawned_concurrency::protocol;
547///
548/// #[protocol]
549/// pub trait CounterProtocol: Send + Sync {
550///     fn increment(&self, amount: u64);                     // send (fire-and-forget)
551///     fn get_count(&self) -> Response<u64>;                 // request (both modes)
552/// }
553///
554/// // Generated:
555/// // - pub mod counter_protocol { pub struct Increment { pub amount: u64 }, pub struct GetCount }
556/// // - pub type CounterRef = Arc<dyn CounterProtocol>;
557/// // - pub trait ToCounterRef { fn to_counter_ref(&self) -> CounterRef; }
558/// // - impl CounterProtocol for ActorRef<A> where A: Handler<Increment> + Handler<GetCount>
559/// ```
560#[proc_macro_attribute]
561pub fn protocol(_attr: TokenStream, item: TokenStream) -> TokenStream {
562    let trait_def = parse_macro_input!(item as ItemTrait);
563    let trait_name = &trait_def.ident;
564    let trait_vis = &trait_def.vis;
565
566    if !trait_def.generics.params.is_empty() {
567        return syn::Error::new_spanned(
568            &trait_def.generics,
569            "generic type parameters on protocol traits are not supported",
570        )
571        .to_compile_error()
572        .into();
573    }
574
575    let base_name = strip_protocol_suffix(&trait_name.to_string());
576    let mod_name = format_ident!("{}", to_snake_case(&trait_name.to_string()));
577    let ref_name = format_ident!("{}Ref", base_name);
578    let converter_trait = format_ident!("To{}Ref", base_name);
579    let converter_method = format_ident!("to_{}_ref", to_snake_case(&base_name));
580
581    let mut methods: Vec<ProtocolMethodInfo> = Vec::new();
582
583    for item in &trait_def.items {
584        if !matches!(item, TraitItem::Fn(_)) {
585            return syn::Error::new_spanned(
586                item,
587                "protocol traits may only contain methods; \
588                 associated types, constants, and other items are not supported",
589            )
590            .to_compile_error()
591            .into();
592        }
593        if let TraitItem::Fn(method) = item {
594            if method.sig.asyncness.is_some() {
595                return syn::Error::new_spanned(
596                    &method.sig,
597                    "protocol methods must not be async; \
598                     use Response<T> as the return type for requests",
599                )
600                .to_compile_error()
601                .into();
602            }
603
604            // Verify first param is &self
605            match method.sig.inputs.first() {
606                Some(FnArg::Receiver(r)) if r.reference.is_some() && r.mutability.is_none() => {}
607                _ => {
608                    return syn::Error::new_spanned(
609                        &method.sig,
610                        "protocol methods must take `&self` as the first parameter",
611                    )
612                    .to_compile_error()
613                    .into();
614                }
615            }
616
617            let method_name = method.sig.ident.clone();
618            let struct_name = format_ident!("{}", to_pascal_case(&method_name.to_string()));
619
620            let mut field_names: Vec<Ident> = Vec::new();
621            let mut field_types: Vec<Type> = Vec::new();
622            let mut params: Vec<FnArg> = Vec::new();
623
624            for arg in method.sig.inputs.iter().skip(1) {
625                if let FnArg::Typed(pat_type) = arg {
626                    if let Pat::Ident(pat_ident) = &*pat_type.pat {
627                        field_names.push(pat_ident.ident.clone());
628                        field_types.push((*pat_type.ty).clone());
629                    } else {
630                        return syn::Error::new_spanned(
631                            &pat_type.pat,
632                            "protocol methods only support simple identifier patterns \
633                             (e.g., `name: Type`)",
634                        )
635                        .to_compile_error()
636                        .into();
637                    }
638                }
639                params.push(arg.clone());
640            }
641
642            let kind = match classify_return_type(&method.sig.output) {
643                Ok(kind) => kind,
644                Err(ty) => {
645                    return syn::Error::new_spanned(
646                        ty,
647                        "unsupported return type in protocol method; \
648                         use Response<T> for requests (works in both async and sync modes), \
649                         Result<(), ActorError> for sends, or no return type for fire-and-forget",
650                    )
651                    .to_compile_error()
652                    .into();
653                }
654            };
655
656            let method_attrs: Vec<Attribute> = method
657                .attrs
658                .iter()
659                .filter(|a| {
660                    a.path().is_ident("doc")
661                        || a.path().is_ident("cfg")
662                        || a.path().is_ident("cfg_attr")
663                })
664                .cloned()
665                .collect();
666            let cfg_attrs: Vec<Attribute> = method
667                .attrs
668                .iter()
669                .filter(|a| a.path().is_ident("cfg") || a.path().is_ident("cfg_attr"))
670                .cloned()
671                .collect();
672
673            methods.push(ProtocolMethodInfo {
674                method_name,
675                struct_name,
676                field_names,
677                field_types,
678                kind,
679                params,
680                ret_type: method.sig.output.clone(),
681                method_attrs,
682                cfg_attrs,
683            });
684        }
685    }
686
687    // Generate message structs
688    // Field types and result types are qualified with `super::` to prevent
689    // name collisions when a generated struct name shadows an imported type.
690    // e.g., `fn event(&self, event: Event)` generates `struct Event { pub event: super::Event }`
691    let msg_structs: Vec<_> = methods
692        .iter()
693        .map(|m| {
694            let struct_name = &m.struct_name;
695            let field_names = &m.field_names;
696            let qualified_field_types: Vec<Type> =
697                m.field_types.iter().map(qualify_type_with_super).collect();
698            let method_attrs = &m.method_attrs;
699            let cfg_attrs = &m.cfg_attrs;
700            let msg_result_ty: Type = match &m.kind {
701                MethodKind::Send => syn::parse_quote! { () },
702                MethodKind::Request(inner) => qualify_type_with_super(inner),
703            };
704
705            if field_names.is_empty() {
706                quote! {
707                    #(#method_attrs)*
708                    #[derive(Clone)]
709                    pub struct #struct_name;
710                    #(#cfg_attrs)*
711                    impl Message for #struct_name {
712                        type Result = #msg_result_ty;
713                    }
714                }
715            } else {
716                quote! {
717                    #(#method_attrs)*
718                    pub struct #struct_name {
719                        #(pub #field_names: #qualified_field_types,)*
720                    }
721                    #(#cfg_attrs)*
722                    impl Message for #struct_name {
723                        type Result = #msg_result_ty;
724                    }
725                }
726            }
727        })
728        .collect();
729
730    // Always generate blanket impls for both runtimes
731    let tasks = quote! { spawned_concurrency::tasks };
732    let threads = quote! { spawned_concurrency::threads };
733    let proto_info = ProtocolInfo {
734        trait_name,
735        mod_name: &mod_name,
736        ref_name: &ref_name,
737        converter_trait: &converter_trait,
738        converter_method: &converter_method,
739    };
740    let tasks_impl = generate_blanket_impl(&proto_info, &methods, &tasks, RuntimeMode::Tasks);
741    let threads_impl = generate_blanket_impl(&proto_info, &methods, &threads, RuntimeMode::Threads);
742    let blanket_impls = quote! { #tasks_impl #threads_impl };
743
744    let ref_doc = format!(
745        "Type-erased reference to any actor implementing [`{trait_name}`].\n\n\
746         Use this type to store protocol references without depending on the concrete actor type."
747    );
748
749    // Protocol traits are consumed via blanket impls and dyn references,
750    // so rustc may report methods as unused during development.
751    let output = quote! {
752        #[allow(dead_code)]
753        #trait_def
754
755        #[doc = #ref_doc]
756        #trait_vis type #ref_name = ::std::sync::Arc<dyn #trait_name>;
757
758        #trait_vis mod #mod_name {
759            use super::*;
760            use spawned_concurrency::message::Message;
761            #(#msg_structs)*
762        }
763
764        #trait_vis trait #converter_trait {
765            fn #converter_method(&self) -> #ref_name;
766        }
767
768        impl #converter_trait for #ref_name {
769            fn #converter_method(&self) -> #ref_name {
770                ::std::sync::Arc::clone(self)
771            }
772        }
773
774        #blanket_impls
775    };
776
777    output.into()
778}
779
780/// Generates `impl Actor` and `Handler<M>` implementations from an annotated impl block.
781///
782/// Place `#[actor]` on an `impl MyStruct` block. The macro extracts annotated methods
783/// and generates the boilerplate needed to run the struct as an actor.
784///
785/// # Protocol Assertion
786///
787/// Use `protocol = TraitName` to verify at compile time that the actor fully implements
788/// a protocol (i.e., that `ActorRef<Self>` implements the protocol trait):
789///
790/// ```ignore
791/// #[actor(protocol = NameServerProtocol)]
792/// impl NameServer { /* ... */ }
793/// ```
794///
795/// The runtime (tasks vs threads) is inferred from whether any handler in the block
796/// is `async fn`. If all handlers are defined outside the `#[actor]` block, put at
797/// least one handler inside so the macro can detect the correct runtime.
798///
799/// For multiple protocols:
800/// ```ignore
801/// #[actor(protocol(RoomProtocol, UserProtocol))]
802/// impl ChatUser { /* ... */ }
803/// ```
804///
805/// # Handler Attributes
806///
807/// | Attribute | Use for |
808/// |-----------|---------|
809/// | `#[request_handler]` | Messages that expect a reply (`Response<T>`) |
810/// | `#[send_handler]` | Fire-and-forget messages |
811/// | `#[handler]` | Generic handler (works for either kind) |
812///
813/// Handler signature: `fn name(&mut self, msg: MessageType, ctx: &Context<Self>) -> ReturnType`
814///
815/// The return type must match the message's `Message::Result`. For request handlers
816/// returning `()`, omit the return type.
817///
818/// # Lifecycle Hooks
819///
820/// | Attribute | Called |
821/// |-----------|--------|
822/// | `#[started]` | After the actor starts, before processing messages |
823/// | `#[stopped]` | After the actor stops processing messages |
824///
825/// Both receive `&mut self` and `&Context<Self>`. Use async or sync to match your runtime.
826///
827/// # Example
828///
829/// ```ignore
830/// use spawned_concurrency::tasks::{Actor, Context, Handler};
831/// use spawned_concurrency::actor;
832///
833/// pub struct MyActor { count: u64 }
834///
835/// #[actor(protocol = CounterProtocol)]
836/// impl MyActor {
837///     pub fn new() -> Self { MyActor { count: 0 } }
838///
839///     #[started]
840///     async fn on_start(&mut self, _ctx: &Context<Self>) {
841///         tracing::info!("Actor started");
842///     }
843///
844///     #[send_handler]
845///     async fn handle_increment(&mut self, msg: Increment, _ctx: &Context<Self>) {
846///         self.count += msg.amount;
847///     }
848///
849///     #[request_handler]
850///     async fn handle_get_count(&mut self, _msg: GetCount, _ctx: &Context<Self>) -> u64 {
851///         self.count
852///     }
853/// }
854/// ```
855#[proc_macro_attribute]
856pub fn actor(attr: TokenStream, item: TokenStream) -> TokenStream {
857    let mut impl_block = parse_macro_input!(item as ItemImpl);
858
859    let self_ty = &impl_block.self_ty;
860    let (impl_generics, _, where_clause) = impl_block.generics.split_for_impl();
861
862    // --- Parse named parameters from #[actor(protocol = X)] or #[actor(protocol(X, Y))] ---
863    let bridge_traits: Vec<Ident> = if attr.is_empty() {
864        Vec::new()
865    } else {
866        let parser = |input: syn::parse::ParseStream| -> syn::Result<Vec<Ident>> {
867            let mut protocols = Vec::new();
868            while !input.is_empty() {
869                let key: Ident = input.parse()?;
870                if key != "protocol" {
871                    return Err(syn::Error::new(
872                        key.span(),
873                        "unknown parameter, expected `protocol`",
874                    ));
875                }
876                if input.peek(syn::Token![=]) {
877                    // protocol = TraitName
878                    let _: syn::Token![=] = input.parse()?;
879                    protocols.push(input.parse()?);
880                } else {
881                    // protocol(Trait1, Trait2)
882                    let content;
883                    syn::parenthesized!(content in input);
884                    let punctuated = content.parse_terminated(Ident::parse, syn::Token![,])?;
885                    protocols.extend(punctuated);
886                }
887                if input.peek(syn::Token![,]) {
888                    let _: syn::Token![,] = input.parse()?;
889                }
890            }
891            Ok(protocols)
892        };
893        match syn::parse::Parser::parse(parser, attr) {
894            Ok(traits) => traits,
895            Err(e) => return e.to_compile_error().into(),
896        }
897    };
898
899    // --- Extract #[started] and #[stopped] lifecycle methods ---
900    let mut started_method: Option<ImplItemFn> = None;
901    let mut stopped_method: Option<ImplItemFn> = None;
902    let mut has_async = false;
903
904    let mut items_to_keep = Vec::new();
905    for item in impl_block.items.drain(..) {
906        if let ImplItem::Fn(ref method) = item {
907            let is_started = method.attrs.iter().any(|a| a.path().is_ident("started"));
908            let is_stopped = method.attrs.iter().any(|a| a.path().is_ident("stopped"));
909
910            if is_started {
911                if started_method.is_some() {
912                    return syn::Error::new_spanned(
913                        &method.sig,
914                        "only one #[started] method is allowed per actor",
915                    )
916                    .to_compile_error()
917                    .into();
918                }
919                if method.attrs.iter().any(|a| {
920                    a.path().is_ident("handler")
921                        || a.path().is_ident("send_handler")
922                        || a.path().is_ident("request_handler")
923                }) {
924                    return syn::Error::new_spanned(
925                        &method.sig,
926                        "#[started] cannot be combined with handler attributes",
927                    )
928                    .to_compile_error()
929                    .into();
930                }
931                // Expect: fn started(&mut self, ctx: &Context<Self>)
932                if method.sig.inputs.len() != 2 {
933                    return syn::Error::new_spanned(
934                        &method.sig,
935                        "#[started] method must take exactly (&mut self, &Context<Self>)",
936                    )
937                    .to_compile_error()
938                    .into();
939                }
940                if !matches!(method.sig.inputs.first(), Some(FnArg::Receiver(r)) if r.mutability.is_some())
941                {
942                    return syn::Error::new_spanned(
943                        &method.sig,
944                        "#[started] method's first parameter must be `&mut self`",
945                    )
946                    .to_compile_error()
947                    .into();
948                }
949                let mut m = method.clone();
950                m.attrs.retain(|a| !a.path().is_ident("started"));
951                m.vis = syn::Visibility::Inherited;
952                m.sig.ident = format_ident!("started");
953                if m.sig.asyncness.is_some() {
954                    has_async = true;
955                }
956                started_method = Some(m);
957                continue;
958            }
959
960            if is_stopped {
961                if stopped_method.is_some() {
962                    return syn::Error::new_spanned(
963                        &method.sig,
964                        "only one #[stopped] method is allowed per actor",
965                    )
966                    .to_compile_error()
967                    .into();
968                }
969                if method.attrs.iter().any(|a| {
970                    a.path().is_ident("handler")
971                        || a.path().is_ident("send_handler")
972                        || a.path().is_ident("request_handler")
973                }) {
974                    return syn::Error::new_spanned(
975                        &method.sig,
976                        "#[stopped] cannot be combined with handler attributes",
977                    )
978                    .to_compile_error()
979                    .into();
980                }
981                // Expect: fn stopped(&mut self, ctx: &Context<Self>)
982                if method.sig.inputs.len() != 2 {
983                    return syn::Error::new_spanned(
984                        &method.sig,
985                        "#[stopped] method must take exactly (&mut self, &Context<Self>)",
986                    )
987                    .to_compile_error()
988                    .into();
989                }
990                if !matches!(method.sig.inputs.first(), Some(FnArg::Receiver(r)) if r.mutability.is_some())
991                {
992                    return syn::Error::new_spanned(
993                        &method.sig,
994                        "#[stopped] method's first parameter must be `&mut self`",
995                    )
996                    .to_compile_error()
997                    .into();
998                }
999                let mut m = method.clone();
1000                m.attrs.retain(|a| !a.path().is_ident("stopped"));
1001                m.vis = syn::Visibility::Inherited;
1002                m.sig.ident = format_ident!("stopped");
1003                if m.sig.asyncness.is_some() {
1004                    has_async = true;
1005                }
1006                stopped_method = Some(m);
1007                continue;
1008            }
1009        }
1010        items_to_keep.push(item);
1011    }
1012    impl_block.items = items_to_keep;
1013
1014    // --- Process handler methods ---
1015    let mut handler_impls = Vec::new();
1016
1017    for item in &mut impl_block.items {
1018        if let ImplItem::Fn(method) = item {
1019            let handler_idx = method.attrs.iter().position(|attr| {
1020                attr.path().is_ident("handler")
1021                    || attr.path().is_ident("send_handler")
1022                    || attr.path().is_ident("request_handler")
1023            });
1024
1025            if let Some(idx) = handler_idx {
1026                method.attrs.remove(idx);
1027
1028                // Collect remaining attributes (e.g. #[cfg(...)]) to propagate
1029                // to the generated Handler impl block.
1030                let extra_attrs: Vec<_> = method
1031                    .attrs
1032                    .iter()
1033                    .filter(|a| {
1034                        !a.path().is_ident("handler")
1035                            && !a.path().is_ident("send_handler")
1036                            && !a.path().is_ident("request_handler")
1037                            && !a.path().is_ident("started")
1038                            && !a.path().is_ident("stopped")
1039                    })
1040                    .cloned()
1041                    .collect();
1042
1043                let method_name = &method.sig.ident;
1044                if method.sig.asyncness.is_some() {
1045                    has_async = true;
1046                }
1047
1048                // Validate handler has exactly 3 parameters: &mut self, msg, ctx
1049                let param_count = method.sig.inputs.len();
1050                if param_count != 3 {
1051                    return syn::Error::new_spanned(
1052                        &method.sig,
1053                        format!(
1054                            "handler method must have 3 parameters (&mut self, msg: M, ctx: &Context<Self>), found {param_count}"
1055                        ),
1056                    )
1057                    .to_compile_error()
1058                    .into();
1059                }
1060
1061                // Extract message type from 2nd parameter (index 1, after &mut self)
1062                let msg_ty = match method.sig.inputs.iter().nth(1) {
1063                    Some(FnArg::Typed(pat_type)) => &*pat_type.ty,
1064                    _ => {
1065                        return syn::Error::new_spanned(
1066                            &method.sig,
1067                            "handler method must have signature: fn(&mut self, msg: M, ctx: &Context<Self>) -> R",
1068                        )
1069                        .to_compile_error()
1070                        .into();
1071                    }
1072                };
1073
1074                // Extract return type (default to () if omitted)
1075                let ret_ty: Box<Type> = match &method.sig.output {
1076                    ReturnType::Default => syn::parse_quote! { () },
1077                    ReturnType::Type(_, ty) => ty.clone(),
1078                };
1079
1080                let handler_impl = if method.sig.asyncness.is_some() {
1081                    quote! {
1082                        #(#extra_attrs)*
1083                        impl #impl_generics Handler<#msg_ty> for #self_ty #where_clause {
1084                            async fn handle(&mut self, msg: #msg_ty, ctx: &Context<Self>) -> #ret_ty {
1085                                self.#method_name(msg, ctx).await
1086                            }
1087                        }
1088                    }
1089                } else {
1090                    quote! {
1091                        #(#extra_attrs)*
1092                        impl #impl_generics Handler<#msg_ty> for #self_ty #where_clause {
1093                            fn handle(&mut self, msg: #msg_ty, ctx: &Context<Self>) -> #ret_ty {
1094                                self.#method_name(msg, ctx)
1095                            }
1096                        }
1097                    }
1098                };
1099
1100                handler_impls.push(handler_impl);
1101            }
1102        }
1103    }
1104
1105    // --- Generate impl Actor ---
1106    let lifecycle_methods: Vec<&ImplItemFn> = [started_method.as_ref(), stopped_method.as_ref()]
1107        .into_iter()
1108        .flatten()
1109        .collect();
1110
1111    let protocol_doc = if bridge_traits.is_empty() {
1112        quote! {}
1113    } else {
1114        let lines: Vec<String> = bridge_traits.iter().map(|t| format!("- [`{t}`]")).collect();
1115        let doc_body = format!(
1116            "# Protocol\n\n\
1117             When started, `ActorRef<{ty}>` implements:\n\n\
1118             {lines}\n\n\
1119             See the protocol trait docs for the full API.",
1120            ty = quote!(#self_ty),
1121            lines = lines.join("\n"),
1122        );
1123        quote! { #[doc = #doc_body] }
1124    };
1125
1126    let actor_impl = quote! {
1127        #protocol_doc
1128        impl #impl_generics Actor for #self_ty #where_clause {
1129            #(#lifecycle_methods)*
1130        }
1131    };
1132
1133    // --- Generate bridge assertions ---
1134    let runtime_path = if has_async {
1135        quote! { spawned_concurrency::tasks }
1136    } else {
1137        quote! { spawned_concurrency::threads }
1138    };
1139
1140    let bridge_asserts: Vec<_> = bridge_traits
1141        .iter()
1142        .map(|trait_name| {
1143            quote! {
1144                const _: () = {
1145                    fn _assert_bridge<__T: #trait_name>() {}
1146                    fn _check() {
1147                        _assert_bridge::<#runtime_path::ActorRef<#self_ty>>();
1148                    }
1149                };
1150            }
1151        })
1152        .collect();
1153
1154    let output = quote! {
1155        #actor_impl
1156        #impl_block
1157        #(#handler_impls)*
1158        #(#bridge_asserts)*
1159    };
1160
1161    output.into()
1162}