serde_struct_tuple_enum_proc_macro/
lib.rs1#![no_std]
2
3extern crate alloc;
4extern crate proc_macro;
5
6use alloc::{
7 fmt::format,
8 string::ToString,
9 vec::Vec,
10};
11
12use itertools::Itertools;
13use proc_macro::TokenStream;
14use proc_macro2::{
15 Span,
16 TokenTree,
17};
18use quote::quote;
19use syn::{
20 parse::{
21 Parse,
22 ParseStream,
23 },
24 parse_macro_input,
25 Data,
26 DeriveInput,
27 Error,
28 Expr,
29 Field,
30 Ident,
31 Lit,
32 Meta,
33};
34
35struct VariantAttrs {
36 tag: Lit,
37}
38
39struct Variant {
40 ident: Ident,
41 attrs: VariantAttrs,
42 field: Field,
43}
44
45struct Input {
46 ident: Ident,
47 tag: Ident,
48 variants: Vec<Variant>,
49}
50
51impl Parse for Input {
52 fn parse(input: ParseStream) -> syn::Result<Self> {
53 let call_site = Span::call_site();
54 let input = DeriveInput::parse(input)?;
55 let ident = input.ident;
56 let data = match input.data {
57 Data::Enum(data) => data,
58 _ => return Err(Error::new(call_site, "input must be a struct")),
59 };
60 let mut tag = None;
61 for attr in input.attrs {
62 if let Meta::List(list) = attr.meta {
63 if list.path.is_ident("tag") {
64 let mut tokens = list.tokens.into_iter();
65 match tokens.next() {
66 Some(TokenTree::Ident(ident)) => {
67 tag = Some(ident);
68 }
69 Some(_) | None => {
70 return Err(Error::new(call_site, "tag attribute must have a type"))
71 }
72 }
73 }
74 }
75 }
76 let tag = match tag {
77 Some(tag) => tag,
78 None => return Err(Error::new(call_site, "missing tag attribute")),
79 };
80 let variants = data
81 .variants
82 .into_iter()
83 .map(|variant| {
84 let mut tag = None;
85 for attr in variant.attrs {
86 if let Meta::NameValue(name_value) = attr.meta {
87 if name_value.path.is_ident("tag") {
88 tag = match name_value.value {
89 Expr::Lit(lit) => Some(lit.lit),
90 _ => return Err(Error::new(call_site, "tag must be a literal")),
91 }
92 }
93 }
94 }
95 let tag = match tag {
96 Some(tag) => tag,
97 None => {
98 return Err(Error::new(
99 call_site,
100 "enum variants must have a tag attribute",
101 ))
102 }
103 };
104 let attrs = VariantAttrs { tag };
105 if variant.fields.len() != 1 {
106 return Err(Error::new(call_site, "enum variants must have one field"));
107 }
108 let field = variant.fields.into_iter().next().unwrap();
109 Ok(Variant {
110 ident: variant.ident,
111 attrs,
112 field,
113 })
114 })
115 .collect::<Result<Vec<_>, _>>()?;
116 Ok(Self {
117 ident,
118 tag,
119 variants,
120 })
121 }
122}
123
124#[proc_macro_derive(DeserializeStructTupleEnum, attributes(tag))]
127pub fn derive_deserialize_struct_tuple_enum(input: TokenStream) -> TokenStream {
128 let input = parse_macro_input!(input as Input);
129 let call_site = Span::call_site();
130
131 let ident = input.ident;
132 let visitor_ident = Ident::new(&format(format_args!("{ident}Visitor")), call_site);
133
134 let tag = input.tag;
135
136 let match_codes = input
137 .variants
138 .iter()
139 .map(|variant| {
140 let variant_ident = &variant.ident;
141 let code = &variant.attrs.tag;
142 let field_ty = &variant.field.ty;
143 quote! {
144 #code => Ok(#ident::#variant_ident(#field_ty::visitor().visit_seq(value)?))
145 }
146 })
147 .collect::<Vec<_>>();
148
149 quote! {
150 impl<'de> serde::Deserialize<'de> for #ident {
151 fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error> where D: serde::Deserializer<'de> {
152 struct #visitor_ident;
153
154 impl<'de> serde::de::Visitor<'de> for #visitor_ident {
155 type Value = #ident;
156
157 fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
158 formatter.write_fmt(format_args!("{} tuple", stringify!(#ident)))
159 }
160
161 fn visit_seq<A>(self, mut value: A) -> Result<Self::Value, A::Error>
162 where
163 A: serde::de::SeqAccess<'de>,
164 {
165 let tag: #tag = value.next_element()?.ok_or_else(|| serde::de::Error::missing_field(stringify!(#ident)))?;
166 match tag {
167 #(#match_codes,)*
168 _ => Err(serde::de::Error::invalid_value(serde::de::Unexpected::TupleVariant, &self)),
169 }
170 }
171 }
172
173 deserializer.deserialize_seq(#visitor_ident)
174 }
175 }
176 }.into()
177}
178
179#[proc_macro_derive(SerializeStructTupleEnum, attributes(tag))]
182pub fn derive_serialize_struct_tuple_enum(input: TokenStream) -> TokenStream {
183 let input = parse_macro_input!(input as Input);
184
185 let ident = input.ident;
186 let tag_type = input.tag;
187
188 let (serialize_variant, tag_variant, tag_const_variant): (Vec<_>, Vec<_>, Vec<_>) = input
189 .variants
190 .iter()
191 .map(|variant| {
192 let variant_ident = &variant.ident;
193 let variant_const_ident = Ident::new(
194 &format(format_args!(
195 "{}_TAG",
196 variant_ident.to_string().to_uppercase()
197 )),
198 variant_ident.span(),
199 );
200 let tag = &variant.attrs.tag;
201 (
202 quote! {
203 #ident::#variant_ident(inner) => {
204 let mut seq = serializer.serialize_seq(None)?;
205 seq.serialize_element(&#tag)?;
206 inner.serialize_fields_to_seq(&mut seq)?;
207 seq.end()
208 }
209 },
210 quote! {
211 #ident::#variant_ident(_) => #tag,
212 },
213 quote! {
214 pub const #variant_const_ident: #tag_type = #tag;
215 },
216 )
217 })
218 .multiunzip();
219
220 quote! {
221 impl serde::Serialize for #ident {
222 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
223 where
224 S: serde::Serializer {
225 use serde::ser::SerializeSeq;
226 match self {
227 #(#serialize_variant)*
228 }
229 }
230 }
231
232 impl #ident {
233 #(#tag_const_variant)*
234
235 pub fn tag(&self) -> #tag_type {
236 match self {
237 #(#tag_variant)*
238 }
239 }
240 }
241 }
242 .into()
243}