1use proc_macro::{self, TokenStream};
2use proc_macro2::Span;
3use quote::quote;
4use syn::{parse_macro_input, DeriveInput, Fields, FieldsUnnamed, Ident};
5
6#[proc_macro_derive(TaggedSerde, attributes(tagged_serde))]
7pub fn derive(input: TokenStream) -> TokenStream {
8 let input: DeriveInput = parse_macro_input!(input);
9 let ident = input.ident;
10
11 let syn::Data::Enum(input) = input.data else {
12 return quote!{
14 compile_error!("not an enum");
15 }.into();
16 };
18
19 let variants = input.variants.iter().map(|v| {
20 let variant_name = &v.ident;
21
22 let tag = v
23 .attrs
24 .iter()
25 .find(|attr| {
26 attr.meta
27 .path()
28 .get_ident()
29 .map_or(false, |i| i == "tagged_serde")
30 })
31 .map(|attr| {
32 let nv = attr.meta.require_name_value().expect("name-value");
33 &nv.value
34 })
35 .expect("No enum tag found for {variant_name}");
36
37 let number_of_fields = match &v.fields {
38 Fields::Unnamed(FieldsUnnamed {
39 paren_token: _,
40 unnamed,
41 }) => Some(unnamed.len()),
42 Fields::Unit => None,
43 _ => unimplemented!(),
44 };
45
46 if let Some(number_of_fields) = number_of_fields {
47 let field_names : Vec<_> = (0..number_of_fields).map(|n| Ident::new(&format!("field{n}"), Span::call_site())).collect();
48
49 quote! {
50 #ident::#variant_name(#( #field_names ),*) => (#tag as u64, #( #field_names ),*).serialize(serializer)
52 }
53 } else {
54 quote! {
55 #ident::#variant_name => (#tag as u64).serialize(serializer)
56 }
57 }
58 });
59
60 let deser_variants = input.variants.iter().map(|v| {
61 let variant_name = &v.ident;
62
63 let tag = v
64 .attrs
65 .iter()
66 .find(|attr| {
67 attr.meta
68 .path()
69 .get_ident()
70 .map_or(false, |i| i == "tagged_serde")
71 })
72 .map(|attr| {
73 let nv = attr.meta.require_name_value().expect("name-value");
74 &nv.value
75 })
76 .expect("No enum tag found for {variant_name}");
77
78 let number_of_fields = match &v.fields {
79 Fields::Unnamed(FieldsUnnamed {
80 paren_token: _,
81 unnamed,
82 }) => Some(unnamed.len()),
83 Fields::Unit => None,
84 _ => unimplemented!(),
85 };
86
87 let variant_pattern = if let Some(number_of_fields) = number_of_fields {
88 let variant_args: Vec<_> = (0..number_of_fields)
89 .map(|_| {
90 quote! {
91 seq
92 .next_element().map_err(|e| A::Error::custom(format!("failed to read variant with tag {}: {}", tag, e)))?
93 .ok_or_else(|| A::Error::custom(format!("failed to read variant with tag {}", tag)))?
94 }
95 })
96 .collect();
97 quote! {
98 (#( #variant_args ),*)
99 }
100 } else {
101 quote! {
102 }
103 };
104
105 quote! {
106 #tag => {
107 Ok(#ident::#variant_name #variant_pattern)
108 }
109 }
110 });
111
112 let output = quote! {
114 impl ::serde::Serialize for #ident {
115 fn serialize<S>(&self, serializer: S) -> ::core::result::Result<S::Ok, S::Error>
116 where
117 S: ::serde::Serializer,
118 {
119 match self {
120 #( #variants ),*
121 }
122 }
123 }
124
125 impl<'de> Deserialize<'de> for #ident {
126 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
127 where
128 D: serde::Deserializer<'de>,
129 {
130 use serde::de::Error;
131 struct Visitor;
132
133 impl<'d> serde::de::Visitor<'d> for Visitor {
134 type Value = #ident;
135
136 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
137 formatter.write_str("either a string or an int")
138 }
139
140 fn visit_seq<A: serde::de::SeqAccess<'d>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
141 let tag: u64 = seq
142 .next_element()?
143 .ok_or_else(|| A::Error::custom("failed to read logger field tag"))?;
144 match tag {
145 #( #deser_variants ),*
146 _ => Err(A::Error::custom(format!("unknown tag {} when deserializing {}", tag, stringify!(#ident)))),
147 }
148 }
149 }
150
151 deserializer.deserialize_tuple(usize::MAX, Visitor)
156 }
157 }
158 };
159 output.into()
160}