serde_struct_tuple_enum_proc_macro/
lib.rs

1#![no_std]
2
3extern crate alloc;
4extern crate proc_macro;
5
6use alloc::{
7    fmt::format,
8    string::ToString,
9    vec::Vec,
10};
11
12use itertools::Itertools;
13use proc_macro::TokenStream;
14use proc_macro2::{
15    Span,
16    TokenTree,
17};
18use quote::quote;
19use syn::{
20    parse::{
21        Parse,
22        ParseStream,
23    },
24    parse_macro_input,
25    Data,
26    DeriveInput,
27    Error,
28    Expr,
29    Field,
30    Ident,
31    Lit,
32    Meta,
33};
34
35struct VariantAttrs {
36    tag: Lit,
37}
38
39struct Variant {
40    ident: Ident,
41    attrs: VariantAttrs,
42    field: Field,
43}
44
45struct Input {
46    ident: Ident,
47    tag: Ident,
48    variants: Vec<Variant>,
49}
50
51impl Parse for Input {
52    fn parse(input: ParseStream) -> syn::Result<Self> {
53        let call_site = Span::call_site();
54        let input = DeriveInput::parse(input)?;
55        let ident = input.ident;
56        let data = match input.data {
57            Data::Enum(data) => data,
58            _ => return Err(Error::new(call_site, "input must be a struct")),
59        };
60        let mut tag = None;
61        for attr in input.attrs {
62            if let Meta::List(list) = attr.meta {
63                if list.path.is_ident("tag") {
64                    let mut tokens = list.tokens.into_iter();
65                    match tokens.next() {
66                        Some(TokenTree::Ident(ident)) => {
67                            tag = Some(ident);
68                        }
69                        Some(_) | None => {
70                            return Err(Error::new(call_site, "tag attribute must have a type"))
71                        }
72                    }
73                }
74            }
75        }
76        let tag = match tag {
77            Some(tag) => tag,
78            None => return Err(Error::new(call_site, "missing tag attribute")),
79        };
80        let variants = data
81            .variants
82            .into_iter()
83            .map(|variant| {
84                let mut tag = None;
85                for attr in variant.attrs {
86                    if let Meta::NameValue(name_value) = attr.meta {
87                        if name_value.path.is_ident("tag") {
88                            tag = match name_value.value {
89                                Expr::Lit(lit) => Some(lit.lit),
90                                _ => return Err(Error::new(call_site, "tag must be a literal")),
91                            }
92                        }
93                    }
94                }
95                let tag = match tag {
96                    Some(tag) => tag,
97                    None => {
98                        return Err(Error::new(
99                            call_site,
100                            "enum variants must have a tag attribute",
101                        ))
102                    }
103                };
104                let attrs = VariantAttrs { tag };
105                if variant.fields.len() != 1 {
106                    return Err(Error::new(call_site, "enum variants must have one field"));
107                }
108                let field = variant.fields.into_iter().next().unwrap();
109                Ok(Variant {
110                    ident: variant.ident,
111                    attrs,
112                    field,
113                })
114            })
115            .collect::<Result<Vec<_>, _>>()?;
116        Ok(Self {
117            ident,
118            tag,
119            variants,
120        })
121    }
122}
123
124/// Implements `serde::Deserialize` for the enum, assuming each enum variant is a simple wrapper
125/// around implementations of `serde_struct_tuple::DeserializeStructTuple`.
126#[proc_macro_derive(DeserializeStructTupleEnum, attributes(tag))]
127pub fn derive_deserialize_struct_tuple_enum(input: TokenStream) -> TokenStream {
128    let input = parse_macro_input!(input as Input);
129    let call_site = Span::call_site();
130
131    let ident = input.ident;
132    let visitor_ident = Ident::new(&format(format_args!("{ident}Visitor")), call_site);
133
134    let tag = input.tag;
135
136    let match_codes = input
137        .variants
138        .iter()
139        .map(|variant| {
140            let variant_ident = &variant.ident;
141            let code = &variant.attrs.tag;
142            let field_ty = &variant.field.ty;
143            quote! {
144                #code => Ok(#ident::#variant_ident(#field_ty::visitor().visit_seq(value)?))
145            }
146        })
147        .collect::<Vec<_>>();
148
149    quote! {
150        impl<'de> serde::Deserialize<'de> for #ident {
151            fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error> where D: serde::Deserializer<'de> {
152                struct #visitor_ident;
153
154                impl<'de> serde::de::Visitor<'de> for #visitor_ident {
155                    type Value = #ident;
156
157                    fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
158                        formatter.write_fmt(format_args!("{} tuple", stringify!(#ident)))
159                    }
160
161                    fn visit_seq<A>(self, mut value: A) -> Result<Self::Value, A::Error>
162                    where
163                        A: serde::de::SeqAccess<'de>,
164                    {
165                        let tag: #tag = value.next_element()?.ok_or_else(|| serde::de::Error::missing_field(stringify!(#ident)))?;
166                        match tag {
167                            #(#match_codes,)*
168                            _ => Err(serde::de::Error::invalid_value(serde::de::Unexpected::TupleVariant, &self)),
169                        }
170                    }
171                }
172
173                deserializer.deserialize_seq(#visitor_ident)
174            }
175        }
176    }.into()
177}
178
179/// Implements `serde::Serialize` for the enum, assuming each enum variant is a simple wrapper
180/// around implementations of `serde_struct_tuple::SerializeStructTuple`.
181#[proc_macro_derive(SerializeStructTupleEnum, attributes(tag))]
182pub fn derive_serialize_struct_tuple_enum(input: TokenStream) -> TokenStream {
183    let input = parse_macro_input!(input as Input);
184
185    let ident = input.ident;
186    let tag_type = input.tag;
187
188    let (serialize_variant, tag_variant, tag_const_variant): (Vec<_>, Vec<_>, Vec<_>) = input
189        .variants
190        .iter()
191        .map(|variant| {
192            let variant_ident = &variant.ident;
193            let variant_const_ident = Ident::new(
194                &format(format_args!(
195                    "{}_TAG",
196                    variant_ident.to_string().to_uppercase()
197                )),
198                variant_ident.span(),
199            );
200            let tag = &variant.attrs.tag;
201            (
202                quote! {
203                    #ident::#variant_ident(inner) => {
204                        let mut seq = serializer.serialize_seq(None)?;
205                        seq.serialize_element(&#tag)?;
206                        inner.serialize_fields_to_seq(&mut seq)?;
207                        seq.end()
208                    }
209                },
210                quote! {
211                    #ident::#variant_ident(_) => #tag,
212                },
213                quote! {
214                    pub const #variant_const_ident: #tag_type = #tag;
215                },
216            )
217        })
218        .multiunzip();
219
220    quote! {
221        impl serde::Serialize for #ident {
222            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
223            where
224                S: serde::Serializer {
225                    use serde::ser::SerializeSeq;
226                    match self {
227                        #(#serialize_variant)*
228                    }
229                }
230        }
231
232        impl #ident {
233            #(#tag_const_variant)*
234
235            pub fn tag(&self) -> #tag_type {
236                match self {
237                    #(#tag_variant)*
238                }
239            }
240        }
241    }
242    .into()
243}