1use proc_macro::TokenStream;
31use quote::quote;
32use syn::{Data, DeriveInput, Fields, parse_macro_input};
33
34#[proc_macro_derive(ProtocolNode, attributes(protocol, attr))]
69pub fn derive_protocol_node(input: TokenStream) -> TokenStream {
70 let input = parse_macro_input!(input as DeriveInput);
71
72 let name = &input.ident;
73
74 let tag = match extract_tag(&input.attrs) {
75 Ok(Some(tag)) => tag,
76 Ok(None) => {
77 return syn::Error::new_spanned(
78 &input.ident,
79 "ProtocolNode requires #[protocol(tag = \"...\")]",
80 )
81 .to_compile_error()
82 .into();
83 }
84 Err(e) => return e.to_compile_error().into(),
85 };
86
87 let fields = match &input.data {
88 Data::Struct(data) => match &data.fields {
89 Fields::Named(fields) => &fields.named,
90 Fields::Unit => return generate_empty_impl(name, &tag).into(),
91 _ => {
92 return syn::Error::new_spanned(
93 &input.ident,
94 "ProtocolNode only supports named fields or unit structs",
95 )
96 .to_compile_error()
97 .into();
98 }
99 },
100 _ => {
101 return syn::Error::new_spanned(
102 &input.ident,
103 "ProtocolNode can only be derived for structs",
104 )
105 .to_compile_error()
106 .into();
107 }
108 };
109
110 let mut attr_fields = Vec::with_capacity(fields.len());
111 for field in fields {
112 match extract_attr_info(field) {
113 Ok(Some(attr_info)) => attr_fields.push(attr_info),
114 Ok(None) => {}
115 Err(e) => return e.to_compile_error().into(),
116 }
117 }
118
119 let attr_setters: Vec<_> = attr_fields
120 .iter()
121 .map(|info| {
122 let field_ident = &info.field_ident;
123 let attr_name = &info.attr_name;
124
125 match (&info.attr_type, info.optional) {
126 (AttrType::Jid, true) => {
127 quote! {
128 if let Some(jid) = self.#field_ident {
129 builder = builder.attr(#attr_name, jid);
130 }
131 }
132 }
133 (AttrType::Jid, false) => {
134 quote! {
135 builder = builder.attr(#attr_name, self.#field_ident);
136 }
137 }
138 (AttrType::String, true) => {
139 quote! {
140 if let Some(s) = self.#field_ident {
141 builder = builder.attr(#attr_name, s);
142 }
143 }
144 }
145 (AttrType::String, false) => {
146 quote! {
147 builder = builder.attr(#attr_name, self.#field_ident);
148 }
149 }
150 (AttrType::StringEnum, true) => {
151 quote! {
152 if let Some(ref v) = self.#field_ident {
153 builder = builder.attr(#attr_name, v.as_str());
154 }
155 }
156 }
157 (AttrType::StringEnum, false) => {
158 quote! {
159 builder = builder.attr(#attr_name, self.#field_ident.as_str());
160 }
161 }
162 (AttrType::U64, true) | (AttrType::U32, true) => {
163 quote! {
164 if let Some(v) = self.#field_ident {
165 builder = builder.attr(#attr_name, v);
166 }
167 }
168 }
169 (AttrType::U64, false) | (AttrType::U32, false) => {
170 quote! {
171 builder = builder.attr(#attr_name, self.#field_ident);
172 }
173 }
174 }
175 })
176 .collect();
177
178 let field_parsers: Vec<_> = attr_fields
179 .iter()
180 .map(|info| {
181 let field_ident = &info.field_ident;
182 let attr_name = &info.attr_name;
183
184 match (&info.attr_type, info.optional, &info.default) {
185 (AttrType::Jid, false, _) => {
186 quote! {
187 #field_ident: node.attrs().optional_jid(#attr_name)
188 .ok_or_else(|| ::anyhow::anyhow!("missing required attribute '{}'", #attr_name))?
189 }
190 }
191 (AttrType::Jid, true, _) => {
192 quote! {
193 #field_ident: node.attrs().optional_jid(#attr_name)
194 }
195 }
196 (AttrType::String, false, Some(default)) => {
197 quote! {
198 #field_ident: node.attrs().optional_string(#attr_name)
199 .map(|s| s.to_string())
200 .unwrap_or_else(|| #default.to_string())
201 }
202 }
203 (AttrType::String, false, None) => {
204 quote! {
205 #field_ident: node.attrs().required_string(#attr_name)?.to_string()
206 }
207 }
208 (AttrType::String, true, Some(default)) => {
209 quote! {
210 #field_ident: node.attrs().optional_string(#attr_name)
211 .map(|s| s.to_string())
212 .or_else(|| Some(#default.to_string()))
213 }
214 }
215 (AttrType::String, true, None) => {
216 quote! {
217 #field_ident: node.attrs().optional_string(#attr_name).map(|s| s.to_string())
218 }
219 }
220 (AttrType::StringEnum, false, Some(default)) => {
222 quote! {
223 #field_ident: ::wacore::protocol::parse_string_enum(
224 node.attrs().optional_string(#attr_name).as_deref().unwrap_or(#default)
225 )?
226 }
227 }
228 (AttrType::StringEnum, false, None) => {
229 quote! {
230 #field_ident: ::wacore::protocol::parse_string_enum(
231 &node.attrs().optional_string(#attr_name)
232 .ok_or_else(|| ::anyhow::anyhow!("missing required attribute '{}'", #attr_name))?
233 )?
234 }
235 }
236 (AttrType::StringEnum, true, _) => {
237 quote! {
238 #field_ident: node.attrs().optional_string(#attr_name)
239 .map(|s| ::wacore::protocol::parse_string_enum(&s))
240 .transpose()?
241 }
242 }
243 (AttrType::U64, false, _) => {
245 quote! {
246 #field_ident: node.attrs().optional_u64(#attr_name)
247 .ok_or_else(|| ::anyhow::anyhow!("missing required attribute '{}'", #attr_name))?
248 }
249 }
250 (AttrType::U64, true, _) => {
251 quote! {
252 #field_ident: node.attrs().optional_u64(#attr_name)
253 }
254 }
255 (AttrType::U32, false, _) => {
256 quote! {
257 #field_ident: node.attrs().optional_u64(#attr_name)
258 .map(|v| u32::try_from(v))
259 .transpose()
260 .map_err(|_| ::anyhow::anyhow!("attribute '{}' value exceeds u32::MAX", #attr_name))?
261 .ok_or_else(|| ::anyhow::anyhow!("missing required attribute '{}'", #attr_name))?
262 }
263 }
264 (AttrType::U32, true, _) => {
265 quote! {
266 #field_ident: node.attrs().optional_u64(#attr_name)
267 .map(|v| u32::try_from(v))
268 .transpose()
269 .map_err(|_| ::anyhow::anyhow!("attribute '{}' value exceeds u32::MAX", #attr_name))?
270 }
271 }
272 }
273 })
274 .collect();
275
276 let all_have_defaults = attr_fields.iter().all(|info| {
278 info.default.is_some() || info.optional || matches!(info.attr_type, AttrType::StringEnum)
279 });
280
281 let default_impl = if all_have_defaults {
282 let default_fields: Vec<_> = attr_fields
283 .iter()
284 .map(|info| {
285 let field_ident = &info.field_ident;
286 match (&info.attr_type, info.optional, &info.default) {
287 (_, true, Some(default)) => quote! { #field_ident: Some(#default.to_string()) },
288 (_, true, None) => quote! { #field_ident: None },
289 (AttrType::String, false, Some(default)) => {
290 quote! { #field_ident: #default.to_string() }
291 }
292 (AttrType::StringEnum, false, Some(default)) => {
293 quote! { #field_ident: ::wacore::protocol::parse_string_enum(#default)
294 .expect("invalid default for StringEnum field") }
295 }
296 (AttrType::StringEnum, false, None) => {
297 quote! { #field_ident: ::core::default::Default::default() }
298 }
299 _ => unreachable!("all_have_defaults check should prevent this branch"),
300 }
301 })
302 .collect();
303
304 quote! {
305 impl ::core::default::Default for #name {
306 fn default() -> Self {
307 Self {
308 #(#default_fields),*
309 }
310 }
311 }
312 }
313 } else {
314 quote! {}
315 };
316
317 let expanded = quote! {
318 impl ::wacore::protocol::ProtocolNode for #name {
319 fn tag(&self) -> &'static str {
320 #tag
321 }
322
323 fn into_node(self) -> ::wacore_binary::node::Node {
324 let mut builder = ::wacore_binary::builder::NodeBuilder::new(#tag);
325 #(#attr_setters)*
326 builder.build()
327 }
328
329 fn try_from_node_ref(node: &::wacore_binary::node::NodeRef<'_>) -> ::anyhow::Result<Self> {
330 if node.tag != #tag {
331 return Err(::anyhow::anyhow!("expected <{}>, got <{}>", #tag, node.tag));
332 }
333 Ok(Self {
334 #(#field_parsers),*
335 })
336 }
337 }
338
339 #default_impl
340 };
341
342 expanded.into()
343}
344
345#[proc_macro_derive(EmptyNode, attributes(protocol))]
359pub fn derive_empty_node(input: TokenStream) -> TokenStream {
360 let input = parse_macro_input!(input as DeriveInput);
361
362 let name = &input.ident;
363
364 let tag = match extract_tag(&input.attrs) {
365 Ok(Some(tag)) => tag,
366 Ok(None) => {
367 return syn::Error::new_spanned(
368 &input.ident,
369 "EmptyNode requires #[protocol(tag = \"...\")]",
370 )
371 .to_compile_error()
372 .into();
373 }
374 Err(e) => return e.to_compile_error().into(),
375 };
376
377 generate_empty_impl(name, &tag).into()
378}
379
380fn generate_empty_impl(name: &syn::Ident, tag: &str) -> proc_macro2::TokenStream {
381 quote! {
382 impl ::wacore::protocol::ProtocolNode for #name {
383 fn tag(&self) -> &'static str {
384 #tag
385 }
386
387 fn into_node(self) -> ::wacore_binary::node::Node {
388 ::wacore_binary::builder::NodeBuilder::new(#tag).build()
389 }
390
391 fn try_from_node_ref(node: &::wacore_binary::node::NodeRef<'_>) -> ::anyhow::Result<Self> {
392 if node.tag != #tag {
393 return Err(::anyhow::anyhow!("expected <{}>, got <{}>", #tag, node.tag));
394 }
395 Ok(Self)
396 }
397 }
398
399 impl ::core::default::Default for #name {
400 fn default() -> Self {
401 Self
402 }
403 }
404 }
405}
406
/// How a field annotated with `#[attr(...)]` is written to and read from a node attribute.
enum AttrType {
    /// Plain string; used when no type flag is given.
    String,
    /// JID, read with `optional_jid` (`jid` flag).
    Jid,
    /// Enum written via `as_str()` and parsed via `parse_string_enum` (`string_enum` flag).
    StringEnum,
    /// Unsigned 64-bit integer (`u64` flag).
    U64,
    /// Unsigned 32-bit integer, read as u64 and range-checked (`u32` flag).
    U32,
}
417
418struct AttrFieldInfo {
419 field_ident: syn::Ident,
420 attr_name: String,
421 attr_type: AttrType,
422 optional: bool,
423 default: Option<String>,
424}
425
426fn extract_tag(attrs: &[syn::Attribute]) -> Result<Option<String>, syn::Error> {
427 for attr in attrs {
428 if attr.path().is_ident("protocol") {
429 let mut tag = None;
430 attr.parse_nested_meta(|meta| {
431 if meta.path.is_ident("tag") {
432 let value: syn::LitStr = meta.value()?.parse()?;
433 tag = Some(value.value());
434 }
435 Ok(())
436 })?;
437 if tag.is_some() {
438 return Ok(tag);
439 }
440 }
441 }
442 Ok(None)
443}
444
445fn extract_attr_info(field: &syn::Field) -> Result<Option<AttrFieldInfo>, syn::Error> {
446 let field_ident = match field.ident.clone() {
447 Some(ident) => ident,
448 None => return Ok(None),
449 };
450
451 let is_optional = is_option_type(&field.ty);
453
454 for attr in &field.attrs {
455 if attr.path().is_ident("attr") {
456 let mut attr_name = None;
457 let mut default = None;
458 let mut is_jid = false;
459 let mut is_string_enum = false;
460 let mut is_u64 = false;
461 let mut is_u32 = false;
462 let mut explicit_optional = false;
463
464 attr.parse_nested_meta(|meta| {
465 if meta.path.is_ident("name") {
466 let value: syn::LitStr = meta.value()?.parse()?;
467 attr_name = Some(value.value());
468 } else if meta.path.is_ident("default") {
469 let value: syn::LitStr = meta.value()?.parse()?;
470 default = Some(value.value());
471 } else if meta.path.is_ident("jid") {
472 is_jid = true;
473 } else if meta.path.is_ident("string_enum") {
474 is_string_enum = true;
475 } else if meta.path.is_ident("u64") {
476 is_u64 = true;
477 } else if meta.path.is_ident("u32") {
478 is_u32 = true;
479 } else if meta.path.is_ident("optional") {
480 explicit_optional = true;
481 }
482 Ok(())
483 })?;
484
485 match attr_name {
486 Some(name) => {
487 let attr_type = if is_jid {
488 AttrType::Jid
489 } else if is_string_enum {
490 AttrType::StringEnum
491 } else if is_u64 {
492 AttrType::U64
493 } else if is_u32 {
494 AttrType::U32
495 } else {
496 AttrType::String
497 };
498
499 let optional = explicit_optional || is_optional;
501
502 return Ok(Some(AttrFieldInfo {
503 field_ident,
504 attr_name: name,
505 attr_type,
506 optional,
507 default,
508 }));
509 }
510 None => {
511 return Err(syn::Error::new_spanned(
512 attr,
513 "missing required `name` in #[attr(...)]",
514 ));
515 }
516 }
517 }
518 }
519 Ok(None)
520}
521
522fn is_option_type(ty: &syn::Type) -> bool {
524 if let syn::Type::Path(type_path) = ty
525 && let Some(segment) = type_path.path.segments.last()
526 {
527 return segment.ident == "Option";
528 }
529 false
530}
531
532#[proc_macro_derive(WireEnum, attributes(wire, wire_alias, wire_default, wire_fallback))]
562pub fn derive_wire_enum(input: TokenStream) -> TokenStream {
563 let input = parse_macro_input!(input as DeriveInput);
564
565 let variants = match &input.data {
566 Data::Enum(e) => e.variants.clone(),
567 _ => {
568 return syn::Error::new_spanned(&input.ident, "WireEnum can only be derived for enums")
569 .to_compile_error()
570 .into();
571 }
572 };
573
574 let cfg = match parse_enum_level_wire(&input.attrs) {
575 Ok(c) => c,
576 Err(e) => return e.to_compile_error().into(),
577 };
578
579 match cfg.kind {
580 WireKind::IntTagged => expand_wire_enum_int(&input.ident, &variants).into(),
581 WireKind::StringTagged(discriminator) => {
582 expand_wire_enum_tagged(&input.ident, &variants, &discriminator).into()
583 }
584 WireKind::UnitString => expand_wire_enum_unit(&input.ident, &variants).into(),
585 }
586}
587
/// Expansion mode selected by the enum-level `#[wire(...)]` attribute.
enum WireKind {
    /// Default mode: unit variants, each mapped to a string.
    UnitString,
    /// `#[wire(tag = "...")]`; holds the name of the discriminator field.
    StringTagged(String),
    /// `#[wire(kind = "int")]`: variants mapped to `i32` codes.
    IntTagged,
}
595
596struct WireEnumCfg {
597 kind: WireKind,
598}
599
600fn parse_enum_level_wire(attrs: &[syn::Attribute]) -> syn::Result<WireEnumCfg> {
601 let mut tag_field: Option<String> = None;
602 let mut kind_is_int = false;
603
604 for attr in attrs {
605 if !attr.path().is_ident("wire") {
606 continue;
607 }
608 attr.parse_nested_meta(|meta| {
609 if meta.path.is_ident("tag") {
610 let lit: syn::LitStr = meta.value()?.parse()?;
611 tag_field = Some(lit.value());
612 } else if meta.path.is_ident("kind") {
613 let lit: syn::LitStr = meta.value()?.parse()?;
614 match lit.value().as_str() {
615 "int" => kind_is_int = true,
616 "string" => kind_is_int = false,
617 other => {
618 return Err(meta.error(format!(
619 "unknown wire kind {other:?}; expected \"string\" or \"int\""
620 )));
621 }
622 }
623 } else {
624 return Err(meta.error("unknown attribute inside #[wire(...)]"));
625 }
626 Ok(())
627 })?;
628 }
629
630 let kind = if kind_is_int {
631 if tag_field.is_some() {
632 return Err(syn::Error::new_spanned(
633 &attrs[0],
634 "#[wire(kind = \"int\")] is incompatible with #[wire(tag = \"...\")]",
635 ));
636 }
637 WireKind::IntTagged
638 } else if let Some(t) = tag_field {
639 WireKind::StringTagged(t)
640 } else {
641 WireKind::UnitString
642 };
643
644 Ok(WireEnumCfg { kind })
645}
646
/// The value a variant declares with `#[wire = ...]`.
enum VariantWire {
    /// `#[wire = "text"]`, used by the unit-string and tagged modes.
    Str(String),
    /// `#[wire = 123]`, used by int mode; the literal must fit in `i32`.
    Int(i32),
}
653
654struct VariantInfo {
655 ident: syn::Ident,
656 fields: syn::Fields,
657 wire: Option<VariantWire>,
658 aliases: Vec<String>,
659 is_default: bool,
660 is_fallback: bool,
661}
662
663fn read_variant(v: &syn::Variant) -> syn::Result<VariantInfo> {
664 let mut wire: Option<VariantWire> = None;
665 let mut aliases: Vec<String> = Vec::new();
666 let mut is_default = false;
667 let mut is_fallback = false;
668
669 for attr in &v.attrs {
670 if attr.path().is_ident("wire_default") {
671 is_default = true;
672 } else if attr.path().is_ident("wire_fallback") {
673 is_fallback = true;
674 } else if attr.path().is_ident("wire_alias") {
675 if let syn::Meta::NameValue(nv) = &attr.meta
676 && let syn::Expr::Lit(syn::ExprLit {
677 lit: syn::Lit::Str(s),
678 ..
679 }) = &nv.value
680 {
681 aliases.push(s.value());
682 } else {
683 return Err(syn::Error::new_spanned(
684 attr,
685 "expected #[wire_alias = \"...\"] with a string literal",
686 ));
687 }
688 } else if attr.path().is_ident("wire") {
689 if let syn::Meta::NameValue(nv) = &attr.meta {
691 match &nv.value {
692 syn::Expr::Lit(syn::ExprLit {
693 lit: syn::Lit::Str(s),
694 ..
695 }) => wire = Some(VariantWire::Str(s.value())),
696 syn::Expr::Lit(syn::ExprLit {
697 lit: syn::Lit::Int(n),
698 ..
699 }) => {
700 let parsed: i32 = n.base10_parse().map_err(|_| {
703 syn::Error::new_spanned(
704 n,
705 format!(
706 "#[wire = {}] does not fit in i32 ({}..={})",
707 n,
708 i32::MIN,
709 i32::MAX
710 ),
711 )
712 })?;
713 wire = Some(VariantWire::Int(parsed));
714 }
715 _ => {
716 return Err(syn::Error::new_spanned(
717 &nv.value,
718 "#[wire = ...] expects a string or integer literal",
719 ));
720 }
721 }
722 }
723 }
724 }
725
726 Ok(VariantInfo {
727 ident: v.ident.clone(),
728 fields: v.fields.clone(),
729 wire,
730 aliases,
731 is_default,
732 is_fallback,
733 })
734}
735
736fn field_has_wire_skip(attrs: &[syn::Attribute]) -> bool {
737 for attr in attrs {
738 if !attr.path().is_ident("wire") {
739 continue;
740 }
741 let mut found_skip = false;
742 let _ = attr.parse_nested_meta(|meta| {
743 if meta.path.is_ident("skip") {
744 found_skip = true;
745 }
746 Ok(())
747 });
748 if found_skip {
749 return true;
750 }
751 }
752 false
753}
754
755fn expand_wire_enum_unit(
758 name: &syn::Ident,
759 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
760) -> proc_macro2::TokenStream {
761 let mut infos = Vec::with_capacity(variants.len());
762 for v in variants {
763 match read_variant(v) {
764 Ok(info) => infos.push(info),
765 Err(e) => return e.to_compile_error(),
766 }
767 }
768
769 let mut seen: std::collections::HashMap<String, syn::Ident> = Default::default();
770 let mut fallback: Option<&VariantInfo> = None;
771 let mut default_variant: Option<&VariantInfo> = None;
772
773 for info in &infos {
774 if info.is_fallback {
775 if fallback.is_some() {
776 return syn::Error::new_spanned(
777 &info.ident,
778 "only one #[wire_fallback] variant is allowed",
779 )
780 .to_compile_error();
781 }
782 match &info.fields {
783 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {}
784 _ => {
785 return syn::Error::new_spanned(
786 &info.ident,
787 "#[wire_fallback] on a unit-string enum requires VariantName(String)",
788 )
789 .to_compile_error();
790 }
791 }
792 if info.wire.is_some() {
793 return syn::Error::new_spanned(
794 &info.ident,
795 "#[wire_fallback] variant must not carry #[wire = \"...\"]",
796 )
797 .to_compile_error();
798 }
799 fallback = Some(info);
800 if info.is_default {
801 default_variant = Some(info);
802 }
803 continue;
804 }
805 if !matches!(info.fields, syn::Fields::Unit) {
806 return syn::Error::new_spanned(
807 &info.ident,
808 "unit-string WireEnum only supports unit variants (use #[wire_fallback] for a catch-all)",
809 )
810 .to_compile_error();
811 }
812 let Some(VariantWire::Str(s)) = &info.wire else {
813 return syn::Error::new_spanned(&info.ident, "variant needs #[wire = \"...\"]")
814 .to_compile_error();
815 };
816 if let Some(prev) = seen.insert(s.clone(), info.ident.clone()) {
817 return syn::Error::new_spanned(
818 &info.ident,
819 format!("duplicate #[wire = \"{s}\"]; already used by {prev}"),
820 )
821 .to_compile_error();
822 }
823 if info.is_default {
824 if default_variant.is_some() {
825 return syn::Error::new_spanned(&info.ident, "only one #[wire_default] is allowed")
826 .to_compile_error();
827 }
828 default_variant = Some(info);
829 }
830 for alias in &info.aliases {
831 if let Some(prev) = seen.insert(alias.clone(), info.ident.clone()) {
832 return syn::Error::new_spanned(
833 &info.ident,
834 format!(
835 "#[wire_alias = \"{alias}\"] collides with existing wire tag from variant {prev}"
836 ),
837 )
838 .to_compile_error();
839 }
840 }
841 }
842
843 let first_known: Option<&VariantInfo> = infos.iter().find(|i| !i.is_fallback);
844 let default_info = match (default_variant, first_known, fallback) {
845 (Some(d), _, _) => d,
846 (None, Some(f), _) => f,
847 (None, None, Some(fb)) => fb,
848 (None, None, None) => {
849 return syn::Error::new_spanned(name, "WireEnum cannot be derived for empty enums")
850 .to_compile_error();
851 }
852 };
853 let default_ident = &default_info.ident;
854 let default_ctor = if default_info.is_fallback {
855 quote! { #name::#default_ident(::std::string::String::new()) }
856 } else {
857 quote! { #name::#default_ident }
858 };
859
860 let known: Vec<(&syn::Ident, &String)> = infos
861 .iter()
862 .filter(|i| !i.is_fallback)
863 .map(|i| {
864 let VariantWire::Str(s) = i.wire.as_ref().unwrap() else {
865 unreachable!()
866 };
867 (&i.ident, s)
868 })
869 .collect();
870
871 let as_str_arms: Vec<_> = known
874 .iter()
875 .map(|(id, s)| quote! { #name::#id => #s })
876 .collect();
877
878 let try_from_arms: Vec<proc_macro2::TokenStream> = infos
880 .iter()
881 .filter(|i| !i.is_fallback)
882 .flat_map(|i| {
883 let id = &i.ident;
884 let VariantWire::Str(primary) = i.wire.as_ref().unwrap() else {
885 unreachable!()
886 };
887 std::iter::once(primary.clone())
888 .chain(i.aliases.iter().cloned())
889 .map(move |s| quote! { #s => ::core::result::Result::Ok(#name::#id) })
890 })
891 .collect();
892
893 let from_arms: Vec<proc_macro2::TokenStream> = infos
894 .iter()
895 .filter(|i| !i.is_fallback)
896 .flat_map(|i| {
897 let id = &i.ident;
898 let VariantWire::Str(primary) = i.wire.as_ref().unwrap() else {
899 unreachable!()
900 };
901 std::iter::once(primary.clone())
902 .chain(i.aliases.iter().cloned())
903 .map(move |s| quote! { #s => #name::#id })
904 })
905 .collect();
906
907 let as_str_return_ty;
908 let as_str_block;
909 let conversion_impls;
910
911 if let Some(fb) = fallback {
912 let fb_ident = &fb.ident;
913 as_str_return_ty = quote! { &str };
914 as_str_block = quote! {
915 match self {
916 #(#as_str_arms,)*
917 #name::#fb_ident(s) => s.as_str(),
918 }
919 };
920 conversion_impls = quote! {
921 impl ::core::convert::From<&str> for #name {
922 fn from(value: &str) -> Self {
923 match value {
924 #(#from_arms,)*
925 other => #name::#fb_ident(other.to_string()),
926 }
927 }
928 }
929
930 impl ::wacore::protocol::ParseStringEnum for #name {
931 fn parse_from_str(s: &str) -> ::anyhow::Result<Self> {
932 ::core::result::Result::Ok(::core::convert::From::from(s))
933 }
934 }
935 };
936 } else {
937 as_str_return_ty = quote! { &'static str };
938 as_str_block = quote! {
939 match self {
940 #(#as_str_arms),*
941 }
942 };
943 conversion_impls = quote! {
944 impl ::core::convert::TryFrom<&str> for #name {
945 type Error = ::anyhow::Error;
946 fn try_from(value: &str) -> ::core::result::Result<Self, Self::Error> {
947 match value {
948 #(#try_from_arms),*,
949 _ => ::core::result::Result::Err(
950 ::anyhow::anyhow!("unknown {}: {}", stringify!(#name), value)
951 ),
952 }
953 }
954 }
955
956 impl ::wacore::protocol::ParseStringEnum for #name {
957 fn parse_from_str(s: &str) -> ::anyhow::Result<Self> {
958 ::core::convert::TryFrom::try_from(s)
959 }
960 }
961 };
962 }
963
964 let deserialize_impl = if fallback.is_some() {
965 quote! {
966 impl<'de> ::serde::Deserialize<'de> for #name {
967 fn deserialize<D: ::serde::Deserializer<'de>>(
968 deserializer: D,
969 ) -> ::core::result::Result<Self, D::Error> {
970 let s = <::std::string::String as ::serde::Deserialize>::deserialize(deserializer)?;
971 ::core::result::Result::Ok(<Self as ::core::convert::From<&str>>::from(s.as_str()))
972 }
973 }
974 }
975 } else {
976 quote! {
977 impl<'de> ::serde::Deserialize<'de> for #name {
978 fn deserialize<D: ::serde::Deserializer<'de>>(
979 deserializer: D,
980 ) -> ::core::result::Result<Self, D::Error> {
981 let s = <::std::string::String as ::serde::Deserialize>::deserialize(deserializer)?;
982 <Self as ::core::convert::TryFrom<&str>>::try_from(s.as_str())
983 .map_err(|e| <D::Error as ::serde::de::Error>::custom(e.to_string()))
984 }
985 }
986 }
987 };
988
989 quote! {
990 impl #name {
991 pub fn as_str(&self) -> #as_str_return_ty {
993 #as_str_block
994 }
995 }
996
997 impl ::core::fmt::Display for #name {
998 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
999 f.write_str(self.as_str())
1000 }
1001 }
1002
1003 #conversion_impls
1004
1005 impl ::core::default::Default for #name {
1006 fn default() -> Self {
1007 #default_ctor
1008 }
1009 }
1010
1011 impl ::serde::Serialize for #name {
1012 fn serialize<S: ::serde::Serializer>(
1013 &self,
1014 serializer: S,
1015 ) -> ::core::result::Result<S::Ok, S::Error> {
1016 serializer.serialize_str(self.as_str())
1017 }
1018 }
1019
1020 #deserialize_impl
1021 }
1022}
1023
1024fn expand_wire_enum_int(
1027 name: &syn::Ident,
1028 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
1029) -> proc_macro2::TokenStream {
1030 let mut infos = Vec::with_capacity(variants.len());
1031 for v in variants {
1032 match read_variant(v) {
1033 Ok(info) => infos.push(info),
1034 Err(e) => return e.to_compile_error(),
1035 }
1036 }
1037
1038 let mut fallback: Option<&VariantInfo> = None;
1039 let mut seen: std::collections::HashMap<i32, syn::Ident> = Default::default();
1040
1041 for info in &infos {
1042 if info.is_fallback {
1043 if fallback.is_some() {
1044 return syn::Error::new_spanned(
1045 &info.ident,
1046 "only one #[wire_fallback] is allowed",
1047 )
1048 .to_compile_error();
1049 }
1050 match &info.fields {
1051 syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {}
1052 _ => {
1053 return syn::Error::new_spanned(
1054 &info.ident,
1055 "#[wire_fallback] in int mode requires VariantName(i32)",
1056 )
1057 .to_compile_error();
1058 }
1059 }
1060 fallback = Some(info);
1061 continue;
1062 }
1063 if !matches!(info.fields, syn::Fields::Unit) {
1064 return syn::Error::new_spanned(
1065 &info.ident,
1066 "int-mode WireEnum variants must be unit variants (except the #[wire_fallback])",
1067 )
1068 .to_compile_error();
1069 }
1070 let Some(VariantWire::Int(n)) = &info.wire else {
1071 return syn::Error::new_spanned(&info.ident, "variant needs #[wire = NUMBER]")
1072 .to_compile_error();
1073 };
1074 if let Some(prev) = seen.insert(*n, info.ident.clone()) {
1075 return syn::Error::new_spanned(
1076 &info.ident,
1077 format!("duplicate #[wire = {n}]; already used by {prev}"),
1078 )
1079 .to_compile_error();
1080 }
1081 }
1082
1083 let Some(fb) = fallback else {
1084 return syn::Error::new_spanned(
1085 name,
1086 "int-mode WireEnum requires a #[wire_fallback] variant like Unknown(i32)",
1087 )
1088 .to_compile_error();
1089 };
1090 let fb_ident = &fb.ident;
1091
1092 let code_arms: Vec<_> = infos
1093 .iter()
1094 .filter(|i| !i.is_fallback)
1095 .map(|i| {
1096 let id = &i.ident;
1097 let VariantWire::Int(n) = i.wire.as_ref().unwrap() else {
1098 unreachable!()
1099 };
1100 let lit = proc_macro2::Literal::i32_suffixed(*n);
1101 quote! { #name::#id => #lit }
1102 })
1103 .collect();
1104
1105 let from_arms: Vec<_> = infos
1106 .iter()
1107 .filter(|i| !i.is_fallback)
1108 .map(|i| {
1109 let id = &i.ident;
1110 let VariantWire::Int(n) = i.wire.as_ref().unwrap() else {
1111 unreachable!()
1112 };
1113 let lit = proc_macro2::Literal::i32_suffixed(*n);
1114 quote! { #lit => #name::#id }
1115 })
1116 .collect();
1117
1118 quote! {
1119 impl #name {
1120 pub fn code(&self) -> i32 {
1122 match self {
1123 #(#code_arms,)*
1124 #name::#fb_ident(n) => *n,
1125 }
1126 }
1127 }
1128
1129 impl ::core::convert::From<i32> for #name {
1130 fn from(code: i32) -> Self {
1131 match code {
1132 #(#from_arms,)*
1133 other => #name::#fb_ident(other),
1134 }
1135 }
1136 }
1137
1138 impl ::serde::Serialize for #name {
1139 fn serialize<S: ::serde::Serializer>(
1140 &self,
1141 serializer: S,
1142 ) -> ::core::result::Result<S::Ok, S::Error> {
1143 serializer.serialize_i32(self.code())
1144 }
1145 }
1146
1147 impl<'de> ::serde::Deserialize<'de> for #name {
1148 fn deserialize<D: ::serde::Deserializer<'de>>(
1149 deserializer: D,
1150 ) -> ::core::result::Result<Self, D::Error> {
1151 let n = <i32 as ::serde::Deserialize>::deserialize(deserializer)?;
1152 ::core::result::Result::Ok(<Self as ::core::convert::From<i32>>::from(n))
1153 }
1154 }
1155 }
1156}
1157
1158fn expand_wire_enum_tagged(
1161 name: &syn::Ident,
1162 variants: &syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>,
1163 discriminator: &str,
1164) -> proc_macro2::TokenStream {
1165 let mut infos = Vec::with_capacity(variants.len());
1166 for v in variants {
1167 match read_variant(v) {
1168 Ok(info) => infos.push(info),
1169 Err(e) => return e.to_compile_error(),
1170 }
1171 }
1172
1173 let mut seen: std::collections::HashMap<String, syn::Ident> = Default::default();
1174 let mut fallback: Option<&VariantInfo> = None;
1175
1176 for info in &infos {
1177 if info.is_fallback {
1178 if fallback.is_some() {
1179 return syn::Error::new_spanned(
1180 &info.ident,
1181 "only one #[wire_fallback] is allowed",
1182 )
1183 .to_compile_error();
1184 }
1185 let ok = matches!(
1187 &info.fields,
1188 syn::Fields::Named(n)
1189 if n.named.len() == 1
1190 && n.named
1191 .first()
1192 .unwrap()
1193 .ident
1194 .as_ref()
1195 .map(|i| i == "tag")
1196 .unwrap_or(false)
1197 );
1198 if !ok {
1199 return syn::Error::new_spanned(
1200 &info.ident,
1201 "tagged #[wire_fallback] must have exactly { tag: String }",
1202 )
1203 .to_compile_error();
1204 }
1205 if info.wire.is_some() {
1206 return syn::Error::new_spanned(
1207 &info.ident,
1208 "#[wire_fallback] variant must not have #[wire = \"...\"]",
1209 )
1210 .to_compile_error();
1211 }
1212 fallback = Some(info);
1213 continue;
1214 }
1215 let Some(VariantWire::Str(s)) = &info.wire else {
1216 return syn::Error::new_spanned(&info.ident, "variant needs #[wire = \"...\"]")
1217 .to_compile_error();
1218 };
1219 if let Some(prev) = seen.insert(s.clone(), info.ident.clone()) {
1220 return syn::Error::new_spanned(
1221 &info.ident,
1222 format!("duplicate #[wire = \"{s}\"]; already used by {prev}"),
1223 )
1224 .to_compile_error();
1225 }
1226 for alias in &info.aliases {
1227 if let Some(prev) = seen.insert(alias.clone(), info.ident.clone()) {
1228 return syn::Error::new_spanned(
1229 &info.ident,
1230 format!(
1231 "#[wire_alias = \"{alias}\"] collides with wire tag from variant {prev}"
1232 ),
1233 )
1234 .to_compile_error();
1235 }
1236 }
1237 }
1238
1239 let wire_tag_arms: Vec<_> = infos
1242 .iter()
1243 .map(|info| {
1244 let id = &info.ident;
1245 if info.is_fallback {
1246 quote! { #name::#id { tag } => tag.as_str() }
1248 } else {
1249 let VariantWire::Str(s) = info.wire.as_ref().unwrap() else {
1250 unreachable!()
1251 };
1252 match &info.fields {
1253 syn::Fields::Unit => quote! { #name::#id => #s },
1254 syn::Fields::Named(_) => quote! { #name::#id { .. } => #s },
1255 syn::Fields::Unnamed(_) => quote! { #name::#id(..) => #s },
1256 }
1257 }
1258 })
1259 .collect();
1260
1261 let serialize_arms: Vec<_> = infos
1264 .iter()
1265 .map(|info| {
1266 let id = &info.ident;
1267 if info.is_fallback {
1268 quote! { #name::#id { tag: _ } => {} }
1270 } else {
1271 match &info.fields {
1272 syn::Fields::Unit => quote! { #name::#id => {} },
1273 syn::Fields::Named(named) => {
1274 let bindings: Vec<proc_macro2::TokenStream> = named
1275 .named
1276 .iter()
1277 .map(|f| {
1278 let id = f.ident.as_ref().unwrap();
1279 if field_has_wire_skip(&f.attrs) {
1280 quote! { #id: _ }
1281 } else {
1282 quote! { #id }
1283 }
1284 })
1285 .collect();
1286 let entries: Vec<proc_macro2::TokenStream> = named
1287 .named
1288 .iter()
1289 .filter(|f| !field_has_wire_skip(&f.attrs))
1290 .map(|f| {
1291 let id = f.ident.as_ref().unwrap();
1292 let key = id.to_string();
1293 if is_option_type(&f.ty) {
1294 quote! {
1295 if let ::core::option::Option::Some(__v) = #id {
1296 ::serde::ser::SerializeMap::serialize_entry(
1297 &mut map, #key, __v
1298 )?;
1299 }
1300 }
1301 } else {
1302 quote! {
1303 ::serde::ser::SerializeMap::serialize_entry(
1304 &mut map, #key, #id
1305 )?;
1306 }
1307 }
1308 })
1309 .collect();
1310 quote! {
1311 #name::#id { #(#bindings),* } => {
1312 #(#entries)*
1313 }
1314 }
1315 }
1316 syn::Fields::Unnamed(_) => {
1317 quote! {
1318 compile_error!("tagged WireEnum tuple variants are not supported — use named fields or unit");
1319 }
1320 }
1321 }
1322 }
1323 })
1324 .collect();
1325
1326 let tag_ident = quote::format_ident!("{}Tag", name);
1329
1330 let mut tag_variant_tokens: Vec<proc_macro2::TokenStream> = Vec::new();
1331 for info in &infos {
1332 let id = &info.ident;
1333 if info.is_fallback {
1334 tag_variant_tokens.push(quote! {
1335 #[doc = "Unknown wire tag — captured for forward compatibility."]
1336 #[wire_fallback]
1337 Unknown(::std::string::String)
1338 });
1339 continue;
1340 }
1341 let VariantWire::Str(primary) = info.wire.as_ref().unwrap() else {
1342 unreachable!()
1343 };
1344 let alias_attrs = info.aliases.iter().map(|a| quote! { #[wire_alias = #a] });
1350 tag_variant_tokens.push(quote! {
1351 #[wire = #primary]
1352 #(#alias_attrs)*
1353 #id
1354 });
1355 }
1356
1357 let discriminator_lit = discriminator;
1360
1361 quote! {
1362 impl #name {
1363 pub fn wire_tag(&self) -> &str {
1366 match self {
1367 #(#wire_tag_arms,)*
1368 }
1369 }
1370
1371 #[inline]
1373 pub fn tag_name(&self) -> &str {
1374 self.wire_tag()
1375 }
1376 }
1377
1378 impl ::serde::Serialize for #name {
1379 fn serialize<S: ::serde::Serializer>(
1380 &self,
1381 serializer: S,
1382 ) -> ::core::result::Result<S::Ok, S::Error> {
1383 use ::serde::ser::SerializeMap;
1384 let mut map = serializer.serialize_map(None)?;
1385 ::serde::ser::SerializeMap::serialize_entry(
1386 &mut map, #discriminator_lit, self.wire_tag()
1387 )?;
1388 match self {
1389 #(#serialize_arms,)*
1390 }
1391 ::serde::ser::SerializeMap::end(map)
1392 }
1393 }
1394
1395 #[doc = "Auto-generated by `#[derive(WireEnum)]`."]
1399 #[derive(Debug, Clone, PartialEq, Eq, ::wacore::WireEnum)]
1400 #[allow(clippy::enum_variant_names)]
1401 pub enum #tag_ident {
1402 #(#tag_variant_tokens,)*
1403 }
1404 }
1405}