1use proc_macro::TokenStream;
31use quote::quote;
32use syn::{Data, DeriveInput, Fields, parse_macro_input};
33
34#[proc_macro_derive(ProtocolNode, attributes(protocol, attr))]
65pub fn derive_protocol_node(input: TokenStream) -> TokenStream {
66 let input = parse_macro_input!(input as DeriveInput);
67
68 let name = &input.ident;
69
70 let tag = match extract_tag(&input.attrs) {
71 Ok(Some(tag)) => tag,
72 Ok(None) => {
73 return syn::Error::new_spanned(
74 &input.ident,
75 "ProtocolNode requires #[protocol(tag = \"...\")]",
76 )
77 .to_compile_error()
78 .into();
79 }
80 Err(e) => return e.to_compile_error().into(),
81 };
82
83 let fields = match &input.data {
84 Data::Struct(data) => match &data.fields {
85 Fields::Named(fields) => &fields.named,
86 Fields::Unit => return generate_empty_impl(name, &tag).into(),
87 _ => {
88 return syn::Error::new_spanned(
89 &input.ident,
90 "ProtocolNode only supports named fields or unit structs",
91 )
92 .to_compile_error()
93 .into();
94 }
95 },
96 _ => {
97 return syn::Error::new_spanned(
98 &input.ident,
99 "ProtocolNode can only be derived for structs",
100 )
101 .to_compile_error()
102 .into();
103 }
104 };
105
106 let mut attr_fields = Vec::new();
107 for field in fields {
108 match extract_attr_info(field) {
109 Ok(Some(attr_info)) => attr_fields.push(attr_info),
110 Ok(None) => {}
111 Err(e) => return e.to_compile_error().into(),
112 }
113 }
114
115 let attr_setters: Vec<_> = attr_fields
116 .iter()
117 .map(|info| {
118 let field_ident = &info.field_ident;
119 let attr_name = &info.attr_name;
120
121 match (&info.attr_type, info.optional) {
122 (AttrType::Jid, true) => {
123 quote! {
125 if let Some(jid) = self.#field_ident {
126 builder = builder.jid_attr(#attr_name, jid);
127 }
128 }
129 }
130 (AttrType::Jid, false) => {
131 quote! {
133 builder = builder.jid_attr(#attr_name, self.#field_ident);
134 }
135 }
136 (AttrType::String, true) => {
137 quote! {
139 if let Some(s) = self.#field_ident {
140 builder = builder.attr(#attr_name, s);
141 }
142 }
143 }
144 (AttrType::String, false) => {
145 quote! {
147 builder = builder.attr(#attr_name, self.#field_ident);
148 }
149 }
150 }
151 })
152 .collect();
153
154 let field_parsers: Vec<_> = attr_fields
155 .iter()
156 .map(|info| {
157 let field_ident = &info.field_ident;
158 let attr_name = &info.attr_name;
159
160 match (&info.attr_type, info.optional, &info.default) {
161 (AttrType::Jid, false, _) => {
162 quote! {
164 #field_ident: node.attrs().optional_jid(#attr_name)
165 .ok_or_else(|| ::anyhow::anyhow!("missing required attribute '{}'", #attr_name))?
166 }
167 }
168 (AttrType::Jid, true, _) => {
169 quote! {
171 #field_ident: node.attrs().optional_jid(#attr_name)
172 }
173 }
174 (AttrType::String, false, Some(default)) => {
175 quote! {
177 #field_ident: node.attrs().optional_string(#attr_name)
178 .map(|s| s.to_string())
179 .unwrap_or_else(|| #default.to_string())
180 }
181 }
182 (AttrType::String, false, None) => {
183 quote! {
185 #field_ident: node.attrs().required_string(#attr_name)?.to_string()
186 }
187 }
188 (AttrType::String, true, Some(default)) => {
189 quote! {
191 #field_ident: node.attrs().optional_string(#attr_name)
192 .map(|s| s.to_string())
193 .or_else(|| Some(#default.to_string()))
194 }
195 }
196 (AttrType::String, true, None) => {
197 quote! {
199 #field_ident: node.attrs().optional_string(#attr_name).map(|s| s.to_string())
200 }
201 }
202 }
203 })
204 .collect();
205
206 let all_have_defaults = attr_fields
208 .iter()
209 .all(|info| info.default.is_some() || info.optional);
210
211 let default_impl = if all_have_defaults {
212 let default_fields: Vec<_> = attr_fields
213 .iter()
214 .map(|info| {
215 let field_ident = &info.field_ident;
216 match (&info.attr_type, info.optional, &info.default) {
217 (_, true, Some(default)) => quote! { #field_ident: Some(#default.to_string()) },
218 (_, true, None) => quote! { #field_ident: None },
219 (AttrType::String, false, Some(default)) => {
220 quote! { #field_ident: #default.to_string() }
221 }
222 _ => unreachable!("all_have_defaults check should prevent this branch"),
223 }
224 })
225 .collect();
226
227 quote! {
228 impl ::core::default::Default for #name {
229 fn default() -> Self {
230 Self {
231 #(#default_fields),*
232 }
233 }
234 }
235 }
236 } else {
237 quote! {}
238 };
239
240 let expanded = quote! {
241 impl ::wa_rs_core::protocol::ProtocolNode for #name {
242 fn tag(&self) -> &'static str {
243 #tag
244 }
245
246 fn into_node(self) -> ::wa_rs_binary::node::Node {
247 let mut builder = ::wa_rs_binary::builder::NodeBuilder::new(#tag);
248 #(#attr_setters)*
249 builder.build()
250 }
251
252 fn try_from_node(node: &::wa_rs_binary::node::Node) -> ::anyhow::Result<Self> {
253 if node.tag != #tag {
254 return Err(::anyhow::anyhow!("expected <{}>, got <{}>", #tag, node.tag));
255 }
256 Ok(Self {
257 #(#field_parsers),*
258 })
259 }
260 }
261
262 #default_impl
263 };
264
265 expanded.into()
266}
267
268#[proc_macro_derive(EmptyNode, attributes(protocol))]
282pub fn derive_empty_node(input: TokenStream) -> TokenStream {
283 let input = parse_macro_input!(input as DeriveInput);
284
285 let name = &input.ident;
286
287 let tag = match extract_tag(&input.attrs) {
288 Ok(Some(tag)) => tag,
289 Ok(None) => {
290 return syn::Error::new_spanned(
291 &input.ident,
292 "EmptyNode requires #[protocol(tag = \"...\")]",
293 )
294 .to_compile_error()
295 .into();
296 }
297 Err(e) => return e.to_compile_error().into(),
298 };
299
300 generate_empty_impl(name, &tag).into()
301}
302
303fn generate_empty_impl(name: &syn::Ident, tag: &str) -> proc_macro2::TokenStream {
304 quote! {
305 impl ::wa_rs_core::protocol::ProtocolNode for #name {
306 fn tag(&self) -> &'static str {
307 #tag
308 }
309
310 fn into_node(self) -> ::wa_rs_binary::node::Node {
311 ::wa_rs_binary::builder::NodeBuilder::new(#tag).build()
312 }
313
314 fn try_from_node(node: &::wa_rs_binary::node::Node) -> ::anyhow::Result<Self> {
315 if node.tag != #tag {
316 return Err(::anyhow::anyhow!("expected <{}>, got <{}>", #tag, node.tag));
317 }
318 Ok(Self)
319 }
320 }
321
322 impl ::core::default::Default for #name {
323 fn default() -> Self {
324 Self
325 }
326 }
327 }
328}
329
/// Value kind of an `#[attr(...)]` field. It selects which builder and parser calls the
/// `ProtocolNode` derive generates for that field.
enum AttrType {
    /// JID attribute, selected by the `jid` flag: written with `jid_attr`, read with
    /// `optional_jid`.
    Jid,
    /// Plain string attribute (used whenever `jid` is absent): written with `attr`, read
    /// with `optional_string` / `required_string`.
    String,
}
334
335struct AttrFieldInfo {
336 field_ident: syn::Ident,
337 attr_name: String,
338 attr_type: AttrType,
339 optional: bool,
340 default: Option<String>,
341}
342
343fn extract_tag(attrs: &[syn::Attribute]) -> Result<Option<String>, syn::Error> {
344 for attr in attrs {
345 if attr.path().is_ident("protocol") {
346 let mut tag = None;
347 attr.parse_nested_meta(|meta| {
348 if meta.path.is_ident("tag") {
349 let value: syn::LitStr = meta.value()?.parse()?;
350 tag = Some(value.value());
351 }
352 Ok(())
353 })?;
354 if tag.is_some() {
355 return Ok(tag);
356 }
357 }
358 }
359 Ok(None)
360}
361
362fn extract_attr_info(field: &syn::Field) -> Result<Option<AttrFieldInfo>, syn::Error> {
363 let field_ident = match field.ident.clone() {
364 Some(ident) => ident,
365 None => return Ok(None),
366 };
367
368 let is_optional = is_option_type(&field.ty);
370
371 for attr in &field.attrs {
372 if attr.path().is_ident("attr") {
373 let mut attr_name = None;
374 let mut default = None;
375 let mut is_jid = false;
376 let mut explicit_optional = false;
377
378 attr.parse_nested_meta(|meta| {
379 if meta.path.is_ident("name") {
380 let value: syn::LitStr = meta.value()?.parse()?;
381 attr_name = Some(value.value());
382 } else if meta.path.is_ident("default") {
383 let value: syn::LitStr = meta.value()?.parse()?;
384 default = Some(value.value());
385 } else if meta.path.is_ident("jid") {
386 is_jid = true;
387 } else if meta.path.is_ident("optional") {
388 explicit_optional = true;
389 }
390 Ok(())
391 })?;
392
393 match attr_name {
394 Some(name) => {
395 let attr_type = if is_jid {
396 AttrType::Jid
397 } else {
398 AttrType::String
399 };
400
401 let optional = explicit_optional || is_optional;
403
404 return Ok(Some(AttrFieldInfo {
405 field_ident,
406 attr_name: name,
407 attr_type,
408 optional,
409 default,
410 }));
411 }
412 None => {
413 return Err(syn::Error::new_spanned(
414 attr,
415 "missing required `name` in #[attr(...)]",
416 ));
417 }
418 }
419 }
420 }
421 Ok(None)
422}
423
424fn is_option_type(ty: &syn::Type) -> bool {
426 if let syn::Type::Path(type_path) = ty
427 && let Some(segment) = type_path.path.segments.last()
428 {
429 return segment.ident == "Option";
430 }
431 false
432}
433
434#[proc_macro_derive(StringEnum, attributes(str, string_default))]
463pub fn derive_string_enum(input: TokenStream) -> TokenStream {
464 let input = parse_macro_input!(input as DeriveInput);
465
466 let name = &input.ident;
467
468 let variants = match &input.data {
469 Data::Enum(data) => &data.variants,
470 _ => {
471 return syn::Error::new_spanned(
472 &input.ident,
473 "StringEnum can only be derived for enums",
474 )
475 .to_compile_error()
476 .into();
477 }
478 };
479
480 let mut variant_infos = Vec::new();
481 let mut default_variant = None;
482 let mut seen_str_values: std::collections::HashMap<String, syn::Ident> =
483 std::collections::HashMap::new();
484
485 for variant in variants {
486 let variant_ident = &variant.ident;
487
488 if !matches!(variant.fields, syn::Fields::Unit) {
489 return syn::Error::new_spanned(
490 variant_ident,
491 "StringEnum only supports unit variants",
492 )
493 .to_compile_error()
494 .into();
495 }
496
497 let mut str_value = None;
498 let mut is_default = false;
499
500 for attr in &variant.attrs {
501 if attr.path().is_ident("str") {
502 if let syn::Meta::NameValue(nv) = &attr.meta
503 && let syn::Expr::Lit(expr_lit) = &nv.value
504 && let syn::Lit::Str(lit_str) = &expr_lit.lit
505 {
506 str_value = Some(lit_str.value());
507 }
508 } else if attr.path().is_ident("string_default") {
509 is_default = true;
510 }
511 }
512
513 let str_val = match str_value {
514 Some(v) => v,
515 None => {
516 return syn::Error::new_spanned(
517 variant_ident,
518 format!(
519 "StringEnum variant {} requires #[str = \"...\"] attribute",
520 variant_ident
521 ),
522 )
523 .to_compile_error()
524 .into();
525 }
526 };
527
528 if let Some(prev_variant) = seen_str_values.get(&str_val) {
529 return syn::Error::new_spanned(
530 variant_ident,
531 format!(
532 "duplicate #[str = \"{}\"] value; already used by variant `{}`",
533 str_val, prev_variant
534 ),
535 )
536 .to_compile_error()
537 .into();
538 }
539 seen_str_values.insert(str_val.clone(), variant_ident.clone());
540
541 if is_default {
542 if default_variant.is_some() {
543 return syn::Error::new_spanned(
544 variant_ident,
545 "Multiple #[string_default] attributes found; only one variant may be the default",
546 )
547 .to_compile_error()
548 .into();
549 }
550 default_variant = Some(variant_ident.clone());
551 }
552
553 variant_infos.push((variant_ident.clone(), str_val));
554 }
555
556 if variant_infos.is_empty() {
558 return syn::Error::new_spanned(
559 &input.ident,
560 "StringEnum cannot be derived for empty enums",
561 )
562 .to_compile_error()
563 .into();
564 }
565
566 let default_variant = default_variant.unwrap_or_else(|| variant_infos[0].0.clone());
568
569 let as_str_arms: Vec<_> = variant_infos
571 .iter()
572 .map(|(ident, str_val)| {
573 quote! { #name::#ident => #str_val }
574 })
575 .collect();
576
577 let try_from_arms: Vec<_> = variant_infos
579 .iter()
580 .map(|(ident, str_val)| {
581 quote! { #str_val => Ok(#name::#ident) }
582 })
583 .collect();
584
585 let expanded = quote! {
586 impl #name {
587 pub fn as_str(&self) -> &'static str {
589 match self {
590 #(#as_str_arms),*
591 }
592 }
593 }
594
595 impl ::core::fmt::Display for #name {
596 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
597 f.write_str(self.as_str())
598 }
599 }
600
601 impl ::core::convert::TryFrom<&str> for #name {
602 type Error = ::anyhow::Error;
603
604 fn try_from(value: &str) -> ::core::result::Result<Self, Self::Error> {
605 match value {
606 #(#try_from_arms),*,
607 _ => Err(::anyhow::anyhow!("unknown {}: {}", stringify!(#name), value)),
608 }
609 }
610 }
611
612 impl ::core::default::Default for #name {
613 fn default() -> Self {
614 #name::#default_variant
615 }
616 }
617 };
618
619 expanded.into()
620}