Skip to main content

wacore_derive/
lib.rs

1//! Derive macros for wacore protocol types.
2//!
3//! This crate provides derive macros for implementing the `ProtocolNode` trait
4//! on structs that represent WhatsApp protocol nodes.
5//!
6//! # Example
7//!
8//! ```ignore
9//! use wacore_derive::{ProtocolNode, WireEnum};
10//!
11//! /// A query request node.
12//! /// Wire format: `<query request="interactive"/>`
13//! #[derive(ProtocolNode)]
14//! #[protocol(tag = "query")]
15//! pub struct QueryRequest {
16//!     #[attr(name = "request", default = "interactive")]
17//!     pub request_type: String,
18//! }
19//!
20//! /// An enum with string representation.
21//! #[derive(WireEnum)]
22//! pub enum MemberAddMode {
23//!     #[wire = "admin_add"]
24//!     AdminAdd,
25//!     #[wire = "all_member_add"]
26//!     AllMemberAdd,
27//! }
28//! ```
29
30use proc_macro::TokenStream;
31use quote::quote;
32use syn::{Data, DeriveInput, Fields, parse_macro_input};
33
34/// Derive macro for implementing `ProtocolNode` on structs with attributes.
35///
36/// # Attributes
37///
38/// - `#[protocol(tag = "tagname")]` - Required. Specifies the XML tag name.
39/// - `#[attr(name = "attrname")]` - Marks a String field as an XML attribute.
40/// - `#[attr(name = "attrname", default = "value")]` - Attribute with default value.
41///   For `Option<String>` fields, a default always yields `Some(default)`.
42/// - `#[attr(name = "attrname", jid)]` - Marks a Jid field as a JID attribute (required).
43/// - `#[attr(name = "attrname", jid, optional)]` - Marks an Option<Jid> field as optional.
44/// - `#[attr(name = "attrname", string_enum)]` - Marks a field whose type derives `WireEnum` in unit-string mode (uses `as_str()`/`TryFrom`).
45/// - `#[attr(name = "attrname", u64)]` - Marks a u64 numeric attribute.
46/// - `#[attr(name = "attrname", u32)]` - Marks a u32 numeric attribute.
47///   Numeric fields can also be `Option<u64>` / `Option<u32>` for optional attributes.
48///
49/// # Example
50///
51/// ```ignore
52/// #[derive(ProtocolNode)]
53/// #[protocol(tag = "message")]
54/// pub struct MessageStanza {
55///     #[attr(name = "from", jid)]
56///     pub from: Jid,
57///     
58///     #[attr(name = "to", jid)]
59///     pub to: Jid,
60///     
61///     #[attr(name = "id")]
62///     pub id: String,
63///     
64///     #[attr(name = "sender_lid", jid, optional)]
65///     pub sender_lid: Option<Jid>,
66/// }
67/// ```
68#[proc_macro_derive(ProtocolNode, attributes(protocol, attr))]
69pub fn derive_protocol_node(input: TokenStream) -> TokenStream {
70    let input = parse_macro_input!(input as DeriveInput);
71
72    let name = &input.ident;
73
74    let tag = match extract_tag(&input.attrs) {
75        Ok(Some(tag)) => tag,
76        Ok(None) => {
77            return syn::Error::new_spanned(
78                &input.ident,
79                "ProtocolNode requires #[protocol(tag = \"...\")]",
80            )
81            .to_compile_error()
82            .into();
83        }
84        Err(e) => return e.to_compile_error().into(),
85    };
86
87    let fields = match &input.data {
88        Data::Struct(data) => match &data.fields {
89            Fields::Named(fields) => &fields.named,
90            Fields::Unit => return generate_empty_impl(name, &tag).into(),
91            _ => {
92                return syn::Error::new_spanned(
93                    &input.ident,
94                    "ProtocolNode only supports named fields or unit structs",
95                )
96                .to_compile_error()
97                .into();
98            }
99        },
100        _ => {
101            return syn::Error::new_spanned(
102                &input.ident,
103                "ProtocolNode can only be derived for structs",
104            )
105            .to_compile_error()
106            .into();
107        }
108    };
109
110    let mut attr_fields = Vec::with_capacity(fields.len());
111    for field in fields {
112        match extract_attr_info(field) {
113            Ok(Some(attr_info)) => attr_fields.push(attr_info),
114            Ok(None) => {}
115            Err(e) => return e.to_compile_error().into(),
116        }
117    }
118
119    let attr_setters: Vec<_> = attr_fields
120        .iter()
121        .map(|info| {
122            let field_ident = &info.field_ident;
123            let attr_name = &info.attr_name;
124
125            match (&info.attr_type, info.optional) {
126                (AttrType::Jid, true) => {
127                    quote! {
128                        if let Some(jid) = self.#field_ident {
129                            builder = builder.attr(#attr_name, jid);
130                        }
131                    }
132                }
133                (AttrType::Jid, false) => {
134                    quote! {
135                        builder = builder.attr(#attr_name, self.#field_ident);
136                    }
137                }
138                (AttrType::String, true) => {
139                    quote! {
140                        if let Some(s) = self.#field_ident {
141                            builder = builder.attr(#attr_name, s);
142                        }
143                    }
144                }
145                (AttrType::String, false) => {
146                    quote! {
147                        builder = builder.attr(#attr_name, self.#field_ident);
148                    }
149                }
150                (AttrType::StringEnum, true) => {
151                    quote! {
152                        if let Some(ref v) = self.#field_ident {
153                            builder = builder.attr(#attr_name, v.as_str());
154                        }
155                    }
156                }
157                (AttrType::StringEnum, false) => {
158                    quote! {
159                        builder = builder.attr(#attr_name, self.#field_ident.as_str());
160                    }
161                }
162                (AttrType::U64, true) | (AttrType::U32, true) => {
163                    quote! {
164                        if let Some(v) = self.#field_ident {
165                            builder = builder.attr(#attr_name, v);
166                        }
167                    }
168                }
169                (AttrType::U64, false) | (AttrType::U32, false) => {
170                    quote! {
171                        builder = builder.attr(#attr_name, self.#field_ident);
172                    }
173                }
174            }
175        })
176        .collect();
177
178    let field_parsers: Vec<_> = attr_fields
179        .iter()
180        .map(|info| {
181            let field_ident = &info.field_ident;
182            let attr_name = &info.attr_name;
183
184            match (&info.attr_type, info.optional, &info.default) {
185                (AttrType::Jid, false, _) => {
186                    quote! {
187                        #field_ident: node.attrs().optional_jid(#attr_name)
188                            .ok_or_else(|| ::anyhow::anyhow!("missing required attribute '{}'", #attr_name))?
189                    }
190                }
191                (AttrType::Jid, true, _) => {
192                    quote! {
193                        #field_ident: node.attrs().optional_jid(#attr_name)
194                    }
195                }
196                (AttrType::String, false, Some(default)) => {
197                    quote! {
198                        #field_ident: node.attrs().optional_string(#attr_name)
199                            .map(|s| s.to_string())
200                            .unwrap_or_else(|| #default.to_string())
201                    }
202                }
203                (AttrType::String, false, None) => {
204                    quote! {
205                        #field_ident: node.attrs().required_string(#attr_name)?.to_string()
206                    }
207                }
208                (AttrType::String, true, Some(default)) => {
209                    quote! {
210                        #field_ident: node.attrs().optional_string(#attr_name)
211                            .map(|s| s.to_string())
212                            .or_else(|| Some(#default.to_string()))
213                    }
214                }
215                (AttrType::String, true, None) => {
216                    quote! {
217                        #field_ident: node.attrs().optional_string(#attr_name).map(|s| s.to_string())
218                    }
219                }
220                // StringEnum: parse using the `parse_string_enum` helper which tries TryFrom then From.
221                (AttrType::StringEnum, false, Some(default)) => {
222                    quote! {
223                        #field_ident: ::wacore::protocol::parse_string_enum(
224                            node.attrs().optional_string(#attr_name).as_deref().unwrap_or(#default)
225                        )?
226                    }
227                }
228                (AttrType::StringEnum, false, None) => {
229                    quote! {
230                        #field_ident: ::wacore::protocol::parse_string_enum(
231                            &node.attrs().optional_string(#attr_name)
232                                .ok_or_else(|| ::anyhow::anyhow!("missing required attribute '{}'", #attr_name))?
233                        )?
234                    }
235                }
236                (AttrType::StringEnum, true, _) => {
237                    quote! {
238                        #field_ident: node.attrs().optional_string(#attr_name)
239                            .map(|s| ::wacore::protocol::parse_string_enum(&s))
240                            .transpose()?
241                    }
242                }
243                // Numeric types
244                (AttrType::U64, false, _) => {
245                    quote! {
246                        #field_ident: node.attrs().optional_u64(#attr_name)
247                            .ok_or_else(|| ::anyhow::anyhow!("missing required attribute '{}'", #attr_name))?
248                    }
249                }
250                (AttrType::U64, true, _) => {
251                    quote! {
252                        #field_ident: node.attrs().optional_u64(#attr_name)
253                    }
254                }
255                (AttrType::U32, false, _) => {
256                    quote! {
257                        #field_ident: node.attrs().optional_u64(#attr_name)
258                            .map(|v| u32::try_from(v))
259                            .transpose()
260                            .map_err(|_| ::anyhow::anyhow!("attribute '{}' value exceeds u32::MAX", #attr_name))?
261                            .ok_or_else(|| ::anyhow::anyhow!("missing required attribute '{}'", #attr_name))?
262                    }
263                }
264                (AttrType::U32, true, _) => {
265                    quote! {
266                        #field_ident: node.attrs().optional_u64(#attr_name)
267                            .map(|v| u32::try_from(v))
268                            .transpose()
269                            .map_err(|_| ::anyhow::anyhow!("attribute '{}' value exceeds u32::MAX", #attr_name))?
270                    }
271                }
272            }
273        })
274        .collect();
275
276    // Only generate Default impl if all fields have defaults or are optional or have Default impl
277    let all_have_defaults = attr_fields.iter().all(|info| {
278        info.default.is_some() || info.optional || matches!(info.attr_type, AttrType::StringEnum)
279    });
280
281    let default_impl = if all_have_defaults {
282        let default_fields: Vec<_> = attr_fields
283            .iter()
284            .map(|info| {
285                let field_ident = &info.field_ident;
286                match (&info.attr_type, info.optional, &info.default) {
287                    (_, true, Some(default)) => quote! { #field_ident: Some(#default.to_string()) },
288                    (_, true, None) => quote! { #field_ident: None },
289                    (AttrType::String, false, Some(default)) => {
290                        quote! { #field_ident: #default.to_string() }
291                    }
292                    (AttrType::StringEnum, false, Some(default)) => {
293                        quote! { #field_ident: ::wacore::protocol::parse_string_enum(#default)
294                        .expect("invalid default for StringEnum field") }
295                    }
296                    (AttrType::StringEnum, false, None) => {
297                        quote! { #field_ident: ::core::default::Default::default() }
298                    }
299                    _ => unreachable!("all_have_defaults check should prevent this branch"),
300                }
301            })
302            .collect();
303
304        quote! {
305            impl ::core::default::Default for #name {
306                fn default() -> Self {
307                    Self {
308                        #(#default_fields),*
309                    }
310                }
311            }
312        }
313    } else {
314        quote! {}
315    };
316
317    let expanded = quote! {
318        impl ::wacore::protocol::ProtocolNode for #name {
319            fn tag(&self) -> &'static str {
320                #tag
321            }
322
323            fn into_node(self) -> ::wacore_binary::node::Node {
324                let mut builder = ::wacore_binary::builder::NodeBuilder::new(#tag);
325                #(#attr_setters)*
326                builder.build()
327            }
328
329            fn try_from_node_ref(node: &::wacore_binary::node::NodeRef<'_>) -> ::anyhow::Result<Self> {
330                if node.tag != #tag {
331                    return Err(::anyhow::anyhow!("expected <{}>, got <{}>", #tag, node.tag));
332                }
333                Ok(Self {
334                    #(#field_parsers),*
335                })
336            }
337        }
338
339        #default_impl
340    };
341
342    expanded.into()
343}
344
345/// Derive macro for empty protocol nodes (tag only, no attributes).
346///
347/// # Attributes
348///
349/// - `#[protocol(tag = "tagname")]` - Required. Specifies the XML tag name.
350///
351/// # Example
352///
353/// ```ignore
354/// #[derive(EmptyNode)]
355/// #[protocol(tag = "participants")]
356/// pub struct ParticipantsRequest;
357/// ```
358#[proc_macro_derive(EmptyNode, attributes(protocol))]
359pub fn derive_empty_node(input: TokenStream) -> TokenStream {
360    let input = parse_macro_input!(input as DeriveInput);
361
362    let name = &input.ident;
363
364    let tag = match extract_tag(&input.attrs) {
365        Ok(Some(tag)) => tag,
366        Ok(None) => {
367            return syn::Error::new_spanned(
368                &input.ident,
369                "EmptyNode requires #[protocol(tag = \"...\")]",
370            )
371            .to_compile_error()
372            .into();
373        }
374        Err(e) => return e.to_compile_error().into(),
375    };
376
377    generate_empty_impl(name, &tag).into()
378}
379
380fn generate_empty_impl(name: &syn::Ident, tag: &str) -> proc_macro2::TokenStream {
381    quote! {
382        impl ::wacore::protocol::ProtocolNode for #name {
383            fn tag(&self) -> &'static str {
384                #tag
385            }
386
387            fn into_node(self) -> ::wacore_binary::node::Node {
388                ::wacore_binary::builder::NodeBuilder::new(#tag).build()
389            }
390
391            fn try_from_node_ref(node: &::wacore_binary::node::NodeRef<'_>) -> ::anyhow::Result<Self> {
392                if node.tag != #tag {
393                    return Err(::anyhow::anyhow!("expected <{}>, got <{}>", #tag, node.tag));
394                }
395                Ok(Self)
396            }
397        }
398
399        impl ::core::default::Default for #name {
400            fn default() -> Self {
401                Self
402            }
403        }
404    }
405}
406
407enum AttrType {
408    String,
409    Jid,
410    /// A type implementing StringEnum (has `as_str()` and `TryFrom<&str>` or `From<&str>`).
411    StringEnum,
412    /// A u64 numeric attribute.
413    U64,
414    /// A u32 numeric attribute.
415    U32,
416}
417
418struct AttrFieldInfo {
419    field_ident: syn::Ident,
420    attr_name: String,
421    attr_type: AttrType,
422    optional: bool,
423    default: Option<String>,
424}
425
426fn extract_tag(attrs: &[syn::Attribute]) -> Result<Option<String>, syn::Error> {
427    for attr in attrs {
428        if attr.path().is_ident("protocol") {
429            let mut tag = None;
430            attr.parse_nested_meta(|meta| {
431                if meta.path.is_ident("tag") {
432                    let value: syn::LitStr = meta.value()?.parse()?;
433                    tag = Some(value.value());
434                }
435                Ok(())
436            })?;
437            if tag.is_some() {
438                return Ok(tag);
439            }
440        }
441    }
442    Ok(None)
443}
444
445fn extract_attr_info(field: &syn::Field) -> Result<Option<AttrFieldInfo>, syn::Error> {
446    let field_ident = match field.ident.clone() {
447        Some(ident) => ident,
448        None => return Ok(None),
449    };
450
451    // Check if field type is Option<T>
452    let is_optional = is_option_type(&field.ty);
453
454    for attr in &field.attrs {
455        if attr.path().is_ident("attr") {
456            let mut attr_name = None;
457            let mut default = None;
458            let mut is_jid = false;
459            let mut is_string_enum = false;
460            let mut is_u64 = false;
461            let mut is_u32 = false;
462            let mut explicit_optional = false;
463
464            attr.parse_nested_meta(|meta| {
465                if meta.path.is_ident("name") {
466                    let value: syn::LitStr = meta.value()?.parse()?;
467                    attr_name = Some(value.value());
468                } else if meta.path.is_ident("default") {
469                    let value: syn::LitStr = meta.value()?.parse()?;
470                    default = Some(value.value());
471                } else if meta.path.is_ident("jid") {
472                    is_jid = true;
473                } else if meta.path.is_ident("string_enum") {
474                    is_string_enum = true;
475                } else if meta.path.is_ident("u64") {
476                    is_u64 = true;
477                } else if meta.path.is_ident("u32") {
478                    is_u32 = true;
479                } else if meta.path.is_ident("optional") {
480                    explicit_optional = true;
481                }
482                Ok(())
483            })?;
484
485            match attr_name {
486                Some(name) => {
487                    let attr_type = if is_jid {
488                        AttrType::Jid
489                    } else if is_string_enum {
490                        AttrType::StringEnum
491                    } else if is_u64 {
492                        AttrType::U64
493                    } else if is_u32 {
494                        AttrType::U32
495                    } else {
496                        AttrType::String
497                    };
498
499                    // Determine if optional: either explicit marker or Option<T> type
500                    let optional = explicit_optional || is_optional;
501
502                    return Ok(Some(AttrFieldInfo {
503                        field_ident,
504                        attr_name: name,
505                        attr_type,
506                        optional,
507                        default,
508                    }));
509                }
510                None => {
511                    return Err(syn::Error::new_spanned(
512                        attr,
513                        "missing required `name` in #[attr(...)]",
514                    ));
515                }
516            }
517        }
518    }
519    Ok(None)
520}
521
522/// Check if a type is Option<T>
523fn is_option_type(ty: &syn::Type) -> bool {
524    if let syn::Type::Path(type_path) = ty
525        && let Some(segment) = type_path.path.segments.last()
526    {
527        return segment.ident == "Option";
528    }
529    false
530}
531
532// =====================================================================
533// WireEnum — the unified replacement for StringEnum + manual impl Serialize
534// for tagged-with-payload and int-discriminated enums.
535//
536// Modes, inferred from attributes:
537//
538//   1. unit-string  (default when no #[wire(tag=...)] and no #[wire(kind="int")])
539//      Every variant is a unit (or a single #[wire_fallback] tuple with String).
540//      Emits: as_str, TryFrom<&str>/From<&str>, Default, Display, Serialize,
541//             Deserialize, ParseStringEnum. Drop-in replacement for StringEnum.
542//
543//   2. tagged       (enum has #[wire(tag = "type")])
544//      Variants carry payload (named fields or unit). One #[wire = "..."] per
545//      variant; optional #[wire_alias = "..."] adds parser-side aliases;
546//      #[wire(skip)] on a field excludes it from JSON; #[wire_fallback] with
547//      { tag: String } catches unknown tags.
548//      Emits: wire_tag(), impl Serialize (SerializeMap), and a sibling
549//             <Name>Tag unit enum (unit-string WireEnum) for parser dispatch.
550//      No Deserialize — follow-up work; not needed by current consumers.
551//
552//   3. int          (enum has #[wire(kind = "int")])
553//      Unit variants + optional #[wire_fallback] tuple with i32. Each variant
554//      has #[wire = NUM].
555//      Emits: code(), From<i32>, Serialize (as i32), Deserialize (from i32).
556//
557// The wire string/number lives exactly once per variant, in the #[wire = ...]
558// attribute. Everything else is derived.
559// =====================================================================
560
561#[proc_macro_derive(WireEnum, attributes(wire, wire_alias, wire_default, wire_fallback))]
562pub fn derive_wire_enum(input: TokenStream) -> TokenStream {
563    let input = parse_macro_input!(input as DeriveInput);
564
565    let variants = match &input.data {
566        Data::Enum(e) => e.variants.clone(),
567        _ => {
568            return syn::Error::new_spanned(&input.ident, "WireEnum can only be derived for enums")
569                .to_compile_error()
570                .into();
571        }
572    };
573
574    let cfg = match parse_enum_level_wire(&input.attrs) {
575        Ok(c) => c,
576        Err(e) => return e.to_compile_error().into(),
577    };
578
579    match cfg.kind {
580        WireKind::IntTagged => expand_wire_enum_int(&input.ident, &variants).into(),
581        WireKind::StringTagged(discriminator) => {
582            expand_wire_enum_tagged(&input.ident, &variants, &discriminator).into()
583        }
584        WireKind::UnitString => expand_wire_enum_unit(&input.ident, &variants).into(),
585    }
586}
587
588// ----- enum-level config -----
589
590enum WireKind {
591    UnitString,
592    StringTagged(String),
593    IntTagged,
594}
595
596struct WireEnumCfg {
597    kind: WireKind,
598}
599
600fn parse_enum_level_wire(attrs: &[syn::Attribute]) -> syn::Result<WireEnumCfg> {
601    let mut tag_field: Option<String> = None;
602    let mut kind_is_int = false;
603
604    for attr in attrs {
605        if !attr.path().is_ident("wire") {
606            continue;
607        }
608        attr.parse_nested_meta(|meta| {
609            if meta.path.is_ident("tag") {
610                let lit: syn::LitStr = meta.value()?.parse()?;
611                tag_field = Some(lit.value());
612            } else if meta.path.is_ident("kind") {
613                let lit: syn::LitStr = meta.value()?.parse()?;
614                match lit.value().as_str() {
615                    "int" => kind_is_int = true,
616                    "string" => kind_is_int = false,
617                    other => {
618                        return Err(meta.error(format!(
619                            "unknown wire kind {other:?}; expected \"string\" or \"int\""
620                        )));
621                    }
622                }
623            } else {
624                return Err(meta.error("unknown attribute inside #[wire(...)]"));
625            }
626            Ok(())
627        })?;
628    }
629
630    let kind = if kind_is_int {
631        if tag_field.is_some() {
632            return Err(syn::Error::new_spanned(
633                &attrs[0],
634                "#[wire(kind = \"int\")] is incompatible with #[wire(tag = \"...\")]",
635            ));
636        }
637        WireKind::IntTagged
638    } else if let Some(t) = tag_field {
639        WireKind::StringTagged(t)
640    } else {
641        WireKind::UnitString
642    };
643
644    Ok(WireEnumCfg { kind })
645}
646
647// ----- variant-level helpers -----
648
649enum VariantWire {
650    Str(String),
651    Int(i32),
652}
653
654struct VariantInfo {
655    ident: syn::Ident,
656    fields: syn::Fields,
657    wire: Option<VariantWire>,
658    aliases: Vec<String>,
659    is_default: bool,
660    is_fallback: bool,
661}
662
663fn read_variant(v: &syn::Variant) -> syn::Result<VariantInfo> {
664    let mut wire: Option<VariantWire> = None;
665    let mut aliases: Vec<String> = Vec::new();
666    let mut is_default = false;
667    let mut is_fallback = false;
668
669    for attr in &v.attrs {
670        if attr.path().is_ident("wire_default") {
671            is_default = true;
672        } else if attr.path().is_ident("wire_fallback") {
673            is_fallback = true;
674        } else if attr.path().is_ident("wire_alias") {
675            if let syn::Meta::NameValue(nv) = &attr.meta
676                && let syn::Expr::Lit(syn::ExprLit {
677                    lit: syn::Lit::Str(s),
678                    ..
679                }) = &nv.value
680            {
681                aliases.push(s.value());
682            } else {
683                return Err(syn::Error::new_spanned(
684                    attr,
685                    "expected #[wire_alias = \"...\"] with a string literal",
686                ));
687            }
688        } else if attr.path().is_ident("wire") {
689            // Variant-level #[wire = "..."] or #[wire = 101]
690            if let syn::Meta::NameValue(nv) = &attr.meta {
691                match &nv.value {
692                    syn::Expr::Lit(syn::ExprLit {
693                        lit: syn::Lit::Str(s),
694                        ..
695                    }) => wire = Some(VariantWire::Str(s.value())),
696                    syn::Expr::Lit(syn::ExprLit {
697                        lit: syn::Lit::Int(n),
698                        ..
699                    }) => {
700                        // Reject out-of-range literals at macro parse time rather
701                        // than silently wrapping with `as i32`.
702                        let parsed: i32 = n.base10_parse().map_err(|_| {
703                            syn::Error::new_spanned(
704                                n,
705                                format!(
706                                    "#[wire = {}] does not fit in i32 ({}..={})",
707                                    n,
708                                    i32::MIN,
709                                    i32::MAX
710                                ),
711                            )
712                        })?;
713                        wire = Some(VariantWire::Int(parsed));
714                    }
715                    _ => {
716                        return Err(syn::Error::new_spanned(
717                            &nv.value,
718                            "#[wire = ...] expects a string or integer literal",
719                        ));
720                    }
721                }
722            }
723        }
724    }
725
726    Ok(VariantInfo {
727        ident: v.ident.clone(),
728        fields: v.fields.clone(),
729        wire,
730        aliases,
731        is_default,
732        is_fallback,
733    })
734}
735
736fn field_has_wire_skip(attrs: &[syn::Attribute]) -> bool {
737    for attr in attrs {
738        if !attr.path().is_ident("wire") {
739            continue;
740        }
741        let mut found_skip = false;
742        let _ = attr.parse_nested_meta(|meta| {
743            if meta.path.is_ident("skip") {
744                found_skip = true;
745            }
746            Ok(())
747        });
748        if found_skip {
749            return true;
750        }
751    }
752    false
753}
754
755// ================== unit-string mode ==================
756
757fn expand_wire_enum_unit(
758    name: &syn::Ident,
759    variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
760) -> proc_macro2::TokenStream {
761    let mut infos = Vec::with_capacity(variants.len());
762    for v in variants {
763        match read_variant(v) {
764            Ok(info) => infos.push(info),
765            Err(e) => return e.to_compile_error(),
766        }
767    }
768
769    let mut seen: std::collections::HashMap<String, syn::Ident> = Default::default();
770    let mut fallback: Option<&VariantInfo> = None;
771    let mut default_variant: Option<&VariantInfo> = None;
772
773    for info in &infos {
774        if info.is_fallback {
775            if fallback.is_some() {
776                return syn::Error::new_spanned(
777                    &info.ident,
778                    "only one #[wire_fallback] variant is allowed",
779                )
780                .to_compile_error();
781            }
782            match &info.fields {
783                syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {}
784                _ => {
785                    return syn::Error::new_spanned(
786                        &info.ident,
787                        "#[wire_fallback] on a unit-string enum requires VariantName(String)",
788                    )
789                    .to_compile_error();
790                }
791            }
792            if info.wire.is_some() {
793                return syn::Error::new_spanned(
794                    &info.ident,
795                    "#[wire_fallback] variant must not carry #[wire = \"...\"]",
796                )
797                .to_compile_error();
798            }
799            fallback = Some(info);
800            if info.is_default {
801                default_variant = Some(info);
802            }
803            continue;
804        }
805        if !matches!(info.fields, syn::Fields::Unit) {
806            return syn::Error::new_spanned(
807                &info.ident,
808                "unit-string WireEnum only supports unit variants (use #[wire_fallback] for a catch-all)",
809            )
810            .to_compile_error();
811        }
812        let Some(VariantWire::Str(s)) = &info.wire else {
813            return syn::Error::new_spanned(&info.ident, "variant needs #[wire = \"...\"]")
814                .to_compile_error();
815        };
816        if let Some(prev) = seen.insert(s.clone(), info.ident.clone()) {
817            return syn::Error::new_spanned(
818                &info.ident,
819                format!("duplicate #[wire = \"{s}\"]; already used by {prev}"),
820            )
821            .to_compile_error();
822        }
823        if info.is_default {
824            if default_variant.is_some() {
825                return syn::Error::new_spanned(&info.ident, "only one #[wire_default] is allowed")
826                    .to_compile_error();
827            }
828            default_variant = Some(info);
829        }
830        for alias in &info.aliases {
831            if let Some(prev) = seen.insert(alias.clone(), info.ident.clone()) {
832                return syn::Error::new_spanned(
833                    &info.ident,
834                    format!(
835                        "#[wire_alias = \"{alias}\"] collides with existing wire tag from variant {prev}"
836                    ),
837                )
838                .to_compile_error();
839            }
840        }
841    }
842
843    let first_known: Option<&VariantInfo> = infos.iter().find(|i| !i.is_fallback);
844    let default_info = match (default_variant, first_known, fallback) {
845        (Some(d), _, _) => d,
846        (None, Some(f), _) => f,
847        (None, None, Some(fb)) => fb,
848        (None, None, None) => {
849            return syn::Error::new_spanned(name, "WireEnum cannot be derived for empty enums")
850                .to_compile_error();
851        }
852    };
853    let default_ident = &default_info.ident;
854    let default_ctor = if default_info.is_fallback {
855        quote! { #name::#default_ident(::std::string::String::new()) }
856    } else {
857        quote! { #name::#default_ident }
858    };
859
860    let known: Vec<(&syn::Ident, &String)> = infos
861        .iter()
862        .filter(|i| !i.is_fallback)
863        .map(|i| {
864            let VariantWire::Str(s) = i.wire.as_ref().unwrap() else {
865                unreachable!()
866            };
867            (&i.ident, s)
868        })
869        .collect();
870
871    // `as_str()` always returns the PRIMARY tag — aliases are parser-only and
872    // must never surface in serialization.
873    let as_str_arms: Vec<_> = known
874        .iter()
875        .map(|(id, s)| quote! { #name::#id => #s })
876        .collect();
877
878    // For parsing, include primary + each alias; all map to the same variant.
879    let try_from_arms: Vec<proc_macro2::TokenStream> = infos
880        .iter()
881        .filter(|i| !i.is_fallback)
882        .flat_map(|i| {
883            let id = &i.ident;
884            let VariantWire::Str(primary) = i.wire.as_ref().unwrap() else {
885                unreachable!()
886            };
887            std::iter::once(primary.clone())
888                .chain(i.aliases.iter().cloned())
889                .map(move |s| quote! { #s => ::core::result::Result::Ok(#name::#id) })
890        })
891        .collect();
892
893    let from_arms: Vec<proc_macro2::TokenStream> = infos
894        .iter()
895        .filter(|i| !i.is_fallback)
896        .flat_map(|i| {
897            let id = &i.ident;
898            let VariantWire::Str(primary) = i.wire.as_ref().unwrap() else {
899                unreachable!()
900            };
901            std::iter::once(primary.clone())
902                .chain(i.aliases.iter().cloned())
903                .map(move |s| quote! { #s => #name::#id })
904        })
905        .collect();
906
907    let as_str_return_ty;
908    let as_str_block;
909    let conversion_impls;
910
911    if let Some(fb) = fallback {
912        let fb_ident = &fb.ident;
913        as_str_return_ty = quote! { &str };
914        as_str_block = quote! {
915            match self {
916                #(#as_str_arms,)*
917                #name::#fb_ident(s) => s.as_str(),
918            }
919        };
920        conversion_impls = quote! {
921            impl ::core::convert::From<&str> for #name {
922                fn from(value: &str) -> Self {
923                    match value {
924                        #(#from_arms,)*
925                        other => #name::#fb_ident(other.to_string()),
926                    }
927                }
928            }
929
930            impl ::wacore::protocol::ParseStringEnum for #name {
931                fn parse_from_str(s: &str) -> ::anyhow::Result<Self> {
932                    ::core::result::Result::Ok(::core::convert::From::from(s))
933                }
934            }
935        };
936    } else {
937        as_str_return_ty = quote! { &'static str };
938        as_str_block = quote! {
939            match self {
940                #(#as_str_arms),*
941            }
942        };
943        conversion_impls = quote! {
944            impl ::core::convert::TryFrom<&str> for #name {
945                type Error = ::anyhow::Error;
946                fn try_from(value: &str) -> ::core::result::Result<Self, Self::Error> {
947                    match value {
948                        #(#try_from_arms),*,
949                        _ => ::core::result::Result::Err(
950                            ::anyhow::anyhow!("unknown {}: {}", stringify!(#name), value)
951                        ),
952                    }
953                }
954            }
955
956            impl ::wacore::protocol::ParseStringEnum for #name {
957                fn parse_from_str(s: &str) -> ::anyhow::Result<Self> {
958                    ::core::convert::TryFrom::try_from(s)
959                }
960            }
961        };
962    }
963
964    let deserialize_impl = if fallback.is_some() {
965        quote! {
966            impl<'de> ::serde::Deserialize<'de> for #name {
967                fn deserialize<D: ::serde::Deserializer<'de>>(
968                    deserializer: D,
969                ) -> ::core::result::Result<Self, D::Error> {
970                    let s = <::std::string::String as ::serde::Deserialize>::deserialize(deserializer)?;
971                    ::core::result::Result::Ok(<Self as ::core::convert::From<&str>>::from(s.as_str()))
972                }
973            }
974        }
975    } else {
976        quote! {
977            impl<'de> ::serde::Deserialize<'de> for #name {
978                fn deserialize<D: ::serde::Deserializer<'de>>(
979                    deserializer: D,
980                ) -> ::core::result::Result<Self, D::Error> {
981                    let s = <::std::string::String as ::serde::Deserialize>::deserialize(deserializer)?;
982                    <Self as ::core::convert::TryFrom<&str>>::try_from(s.as_str())
983                        .map_err(|e| <D::Error as ::serde::de::Error>::custom(e.to_string()))
984                }
985            }
986        }
987    };
988
989    quote! {
990        impl #name {
991            /// Wire string for this variant (single source of truth).
992            pub fn as_str(&self) -> #as_str_return_ty {
993                #as_str_block
994            }
995        }
996
997        impl ::core::fmt::Display for #name {
998            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
999                f.write_str(self.as_str())
1000            }
1001        }
1002
1003        #conversion_impls
1004
1005        impl ::core::default::Default for #name {
1006            fn default() -> Self {
1007                #default_ctor
1008            }
1009        }
1010
1011        impl ::serde::Serialize for #name {
1012            fn serialize<S: ::serde::Serializer>(
1013                &self,
1014                serializer: S,
1015            ) -> ::core::result::Result<S::Ok, S::Error> {
1016                serializer.serialize_str(self.as_str())
1017            }
1018        }
1019
1020        #deserialize_impl
1021    }
1022}
1023
1024// ================== int mode ==================
1025
1026fn expand_wire_enum_int(
1027    name: &syn::Ident,
1028    variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
1029) -> proc_macro2::TokenStream {
1030    let mut infos = Vec::with_capacity(variants.len());
1031    for v in variants {
1032        match read_variant(v) {
1033            Ok(info) => infos.push(info),
1034            Err(e) => return e.to_compile_error(),
1035        }
1036    }
1037
1038    let mut fallback: Option<&VariantInfo> = None;
1039    let mut seen: std::collections::HashMap<i32, syn::Ident> = Default::default();
1040
1041    for info in &infos {
1042        if info.is_fallback {
1043            if fallback.is_some() {
1044                return syn::Error::new_spanned(
1045                    &info.ident,
1046                    "only one #[wire_fallback] is allowed",
1047                )
1048                .to_compile_error();
1049            }
1050            match &info.fields {
1051                syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {}
1052                _ => {
1053                    return syn::Error::new_spanned(
1054                        &info.ident,
1055                        "#[wire_fallback] in int mode requires VariantName(i32)",
1056                    )
1057                    .to_compile_error();
1058                }
1059            }
1060            fallback = Some(info);
1061            continue;
1062        }
1063        if !matches!(info.fields, syn::Fields::Unit) {
1064            return syn::Error::new_spanned(
1065                &info.ident,
1066                "int-mode WireEnum variants must be unit variants (except the #[wire_fallback])",
1067            )
1068            .to_compile_error();
1069        }
1070        let Some(VariantWire::Int(n)) = &info.wire else {
1071            return syn::Error::new_spanned(&info.ident, "variant needs #[wire = NUMBER]")
1072                .to_compile_error();
1073        };
1074        if let Some(prev) = seen.insert(*n, info.ident.clone()) {
1075            return syn::Error::new_spanned(
1076                &info.ident,
1077                format!("duplicate #[wire = {n}]; already used by {prev}"),
1078            )
1079            .to_compile_error();
1080        }
1081    }
1082
1083    let Some(fb) = fallback else {
1084        return syn::Error::new_spanned(
1085            name,
1086            "int-mode WireEnum requires a #[wire_fallback] variant like Unknown(i32)",
1087        )
1088        .to_compile_error();
1089    };
1090    let fb_ident = &fb.ident;
1091
1092    let code_arms: Vec<_> = infos
1093        .iter()
1094        .filter(|i| !i.is_fallback)
1095        .map(|i| {
1096            let id = &i.ident;
1097            let VariantWire::Int(n) = i.wire.as_ref().unwrap() else {
1098                unreachable!()
1099            };
1100            let lit = proc_macro2::Literal::i32_suffixed(*n);
1101            quote! { #name::#id => #lit }
1102        })
1103        .collect();
1104
1105    let from_arms: Vec<_> = infos
1106        .iter()
1107        .filter(|i| !i.is_fallback)
1108        .map(|i| {
1109            let id = &i.ident;
1110            let VariantWire::Int(n) = i.wire.as_ref().unwrap() else {
1111                unreachable!()
1112            };
1113            let lit = proc_macro2::Literal::i32_suffixed(*n);
1114            quote! { #lit => #name::#id }
1115        })
1116        .collect();
1117
1118    quote! {
1119        impl #name {
1120            /// Numeric wire code for this variant (single source of truth).
1121            pub fn code(&self) -> i32 {
1122                match self {
1123                    #(#code_arms,)*
1124                    #name::#fb_ident(n) => *n,
1125                }
1126            }
1127        }
1128
1129        impl ::core::convert::From<i32> for #name {
1130            fn from(code: i32) -> Self {
1131                match code {
1132                    #(#from_arms,)*
1133                    other => #name::#fb_ident(other),
1134                }
1135            }
1136        }
1137
1138        impl ::serde::Serialize for #name {
1139            fn serialize<S: ::serde::Serializer>(
1140                &self,
1141                serializer: S,
1142            ) -> ::core::result::Result<S::Ok, S::Error> {
1143                serializer.serialize_i32(self.code())
1144            }
1145        }
1146
1147        impl<'de> ::serde::Deserialize<'de> for #name {
1148            fn deserialize<D: ::serde::Deserializer<'de>>(
1149                deserializer: D,
1150            ) -> ::core::result::Result<Self, D::Error> {
1151                let n = <i32 as ::serde::Deserialize>::deserialize(deserializer)?;
1152                ::core::result::Result::Ok(<Self as ::core::convert::From<i32>>::from(n))
1153            }
1154        }
1155    }
1156}
1157
1158// ================== tagged mode ==================
1159
1160fn expand_wire_enum_tagged(
1161    name: &syn::Ident,
1162    variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
1163    discriminator: &str,
1164) -> proc_macro2::TokenStream {
1165    let mut infos = Vec::with_capacity(variants.len());
1166    for v in variants {
1167        match read_variant(v) {
1168            Ok(info) => infos.push(info),
1169            Err(e) => return e.to_compile_error(),
1170        }
1171    }
1172
1173    let mut seen: std::collections::HashMap<String, syn::Ident> = Default::default();
1174    let mut fallback: Option<&VariantInfo> = None;
1175
1176    for info in &infos {
1177        if info.is_fallback {
1178            if fallback.is_some() {
1179                return syn::Error::new_spanned(
1180                    &info.ident,
1181                    "only one #[wire_fallback] is allowed",
1182                )
1183                .to_compile_error();
1184            }
1185            // Must be { tag: String }
1186            let ok = matches!(
1187                &info.fields,
1188                syn::Fields::Named(n)
1189                    if n.named.len() == 1
1190                        && n.named
1191                            .first()
1192                            .unwrap()
1193                            .ident
1194                            .as_ref()
1195                            .map(|i| i == "tag")
1196                            .unwrap_or(false)
1197            );
1198            if !ok {
1199                return syn::Error::new_spanned(
1200                    &info.ident,
1201                    "tagged #[wire_fallback] must have exactly { tag: String }",
1202                )
1203                .to_compile_error();
1204            }
1205            if info.wire.is_some() {
1206                return syn::Error::new_spanned(
1207                    &info.ident,
1208                    "#[wire_fallback] variant must not have #[wire = \"...\"]",
1209                )
1210                .to_compile_error();
1211            }
1212            fallback = Some(info);
1213            continue;
1214        }
1215        let Some(VariantWire::Str(s)) = &info.wire else {
1216            return syn::Error::new_spanned(&info.ident, "variant needs #[wire = \"...\"]")
1217                .to_compile_error();
1218        };
1219        if let Some(prev) = seen.insert(s.clone(), info.ident.clone()) {
1220            return syn::Error::new_spanned(
1221                &info.ident,
1222                format!("duplicate #[wire = \"{s}\"]; already used by {prev}"),
1223            )
1224            .to_compile_error();
1225        }
1226        for alias in &info.aliases {
1227            if let Some(prev) = seen.insert(alias.clone(), info.ident.clone()) {
1228                return syn::Error::new_spanned(
1229                    &info.ident,
1230                    format!(
1231                        "#[wire_alias = \"{alias}\"] collides with wire tag from variant {prev}"
1232                    ),
1233                )
1234                .to_compile_error();
1235            }
1236        }
1237    }
1238
1239    // --- wire_tag(&self) -> &str ---
1240
1241    let wire_tag_arms: Vec<_> = infos
1242        .iter()
1243        .map(|info| {
1244            let id = &info.ident;
1245            if info.is_fallback {
1246                // { tag: String } — return borrowed from the field
1247                quote! { #name::#id { tag } => tag.as_str() }
1248            } else {
1249                let VariantWire::Str(s) = info.wire.as_ref().unwrap() else {
1250                    unreachable!()
1251                };
1252                match &info.fields {
1253                    syn::Fields::Unit => quote! { #name::#id => #s },
1254                    syn::Fields::Named(_) => quote! { #name::#id { .. } => #s },
1255                    syn::Fields::Unnamed(_) => quote! { #name::#id(..) => #s },
1256                }
1257            }
1258        })
1259        .collect();
1260
1261    // --- Serialize arms ---
1262
1263    let serialize_arms: Vec<_> = infos
1264        .iter()
1265        .map(|info| {
1266            let id = &info.ident;
1267            if info.is_fallback {
1268                // Only the discriminator is written (already done before match).
1269                quote! { #name::#id { tag: _ } => {} }
1270            } else {
1271                match &info.fields {
1272                    syn::Fields::Unit => quote! { #name::#id => {} },
1273                    syn::Fields::Named(named) => {
1274                        let bindings: Vec<proc_macro2::TokenStream> = named
1275                            .named
1276                            .iter()
1277                            .map(|f| {
1278                                let id = f.ident.as_ref().unwrap();
1279                                if field_has_wire_skip(&f.attrs) {
1280                                    quote! { #id: _ }
1281                                } else {
1282                                    quote! { #id }
1283                                }
1284                            })
1285                            .collect();
1286                        let entries: Vec<proc_macro2::TokenStream> = named
1287                            .named
1288                            .iter()
1289                            .filter(|f| !field_has_wire_skip(&f.attrs))
1290                            .map(|f| {
1291                                let id = f.ident.as_ref().unwrap();
1292                                let key = id.to_string();
1293                                if is_option_type(&f.ty) {
1294                                    quote! {
1295                                        if let ::core::option::Option::Some(__v) = #id {
1296                                            ::serde::ser::SerializeMap::serialize_entry(
1297                                                &mut map, #key, __v
1298                                            )?;
1299                                        }
1300                                    }
1301                                } else {
1302                                    quote! {
1303                                        ::serde::ser::SerializeMap::serialize_entry(
1304                                            &mut map, #key, #id
1305                                        )?;
1306                                    }
1307                                }
1308                            })
1309                            .collect();
1310                        quote! {
1311                            #name::#id { #(#bindings),* } => {
1312                                #(#entries)*
1313                            }
1314                        }
1315                    }
1316                    syn::Fields::Unnamed(_) => {
1317                        quote! {
1318                            compile_error!("tagged WireEnum tuple variants are not supported — use named fields or unit");
1319                        }
1320                    }
1321                }
1322            }
1323        })
1324        .collect();
1325
1326    // --- Sibling <Name>Tag unit enum (unit-string WireEnum) ---
1327
1328    let tag_ident = quote::format_ident!("{}Tag", name);
1329
1330    let mut tag_variant_tokens: Vec<proc_macro2::TokenStream> = Vec::new();
1331    for info in &infos {
1332        let id = &info.ident;
1333        if info.is_fallback {
1334            tag_variant_tokens.push(quote! {
1335                #[doc = "Unknown wire tag — captured for forward compatibility."]
1336                #[wire_fallback]
1337                Unknown(::std::string::String)
1338            });
1339            continue;
1340        }
1341        let VariantWire::Str(primary) = info.wire.as_ref().unwrap() else {
1342            unreachable!()
1343        };
1344        // Primary tag + aliases collapse into ONE tag variant. The unit-string
1345        // WireEnum derive on the tag enum expands `#[wire_alias = "..."]` into
1346        // extra `From<&str>` arms pointing at the same variant, so parsers see
1347        // `Tag::Foo` regardless of whether the wire tag was the primary or an
1348        // alias.
1349        let alias_attrs = info.aliases.iter().map(|a| quote! { #[wire_alias = #a] });
1350        tag_variant_tokens.push(quote! {
1351            #[wire = #primary]
1352            #(#alias_attrs)*
1353            #id
1354        });
1355    }
1356
1357    // --- Final expansion ---
1358
1359    let discriminator_lit = discriminator;
1360
1361    quote! {
1362        impl #name {
1363            /// The wire tag this variant serializes as — the JSON discriminator
1364            /// and the exact tag the parser dispatches on.
1365            pub fn wire_tag(&self) -> &str {
1366                match self {
1367                    #(#wire_tag_arms,)*
1368                }
1369            }
1370
1371            /// Back-compat alias of [`Self::wire_tag`].
1372            #[inline]
1373            pub fn tag_name(&self) -> &str {
1374                self.wire_tag()
1375            }
1376        }
1377
1378        impl ::serde::Serialize for #name {
1379            fn serialize<S: ::serde::Serializer>(
1380                &self,
1381                serializer: S,
1382            ) -> ::core::result::Result<S::Ok, S::Error> {
1383                use ::serde::ser::SerializeMap;
1384                let mut map = serializer.serialize_map(None)?;
1385                ::serde::ser::SerializeMap::serialize_entry(
1386                    &mut map, #discriminator_lit, self.wire_tag()
1387                )?;
1388                match self {
1389                    #(#serialize_arms,)*
1390                }
1391                ::serde::ser::SerializeMap::end(map)
1392            }
1393        }
1394
1395        /// Sibling unit enum listing every canonical wire tag for parser
1396        /// dispatch. Primary wire tags and any `#[wire_alias]` entries all
1397        /// resolve to the same variant via `From<&str>`.
1398        #[doc = "Auto-generated by `#[derive(WireEnum)]`."]
1399        #[derive(Debug, Clone, PartialEq, Eq, ::wacore::WireEnum)]
1400        #[allow(clippy::enum_variant_names)]
1401        pub enum #tag_ident {
1402            #(#tag_variant_tokens,)*
1403        }
1404    }
1405}