tagged_serde/
lib.rs

1use proc_macro::{self, TokenStream};
2use proc_macro2::Span;
3use quote::quote;
4use syn::{parse_macro_input, DeriveInput, Fields, FieldsUnnamed, Ident};
5
6#[proc_macro_derive(TaggedSerde, attributes(tagged_serde))]
7pub fn derive(input: TokenStream) -> TokenStream {
8    let input: DeriveInput = parse_macro_input!(input);
9    let ident = input.ident;
10
11    let syn::Data::Enum(input) = input.data else {
12        // TODO: make this nice
13        return quote!{
14            compile_error!("not an enum");
15        }.into();
16        // panic!("not an enum");
17    };
18
19    let variants = input.variants.iter().map(|v| {
20        let variant_name = &v.ident;
21
22        let tag = v
23            .attrs
24            .iter()
25            .find(|attr| {
26                attr.meta
27                    .path()
28                    .get_ident()
29                    .map_or(false, |i| i == "tagged_serde")
30            })
31            .map(|attr| {
32                let nv = attr.meta.require_name_value().expect("name-value");
33                &nv.value
34            })
35            .expect("No enum tag found for {variant_name}");
36
37        let number_of_fields = match &v.fields {
38            Fields::Unnamed(FieldsUnnamed {
39                paren_token: _,
40                unnamed,
41            }) => Some(unnamed.len()),
42            Fields::Unit => None,
43            _ => unimplemented!(),
44        };
45
46        if let Some(number_of_fields) = number_of_fields {
47            let field_names : Vec<_> = (0..number_of_fields).map(|n| Ident::new(&format!("field{n}"), Span::call_site())).collect();
48
49            quote! {
50                // FIXME don't hardcode u64
51                #ident::#variant_name(#( #field_names ),*) => (#tag as u64, #( #field_names ),*).serialize(serializer)
52            }
53        } else {
54            quote! {
55                #ident::#variant_name => (#tag as u64).serialize(serializer)
56            }
57        }
58    });
59
60    let deser_variants = input.variants.iter().map(|v| {
61        let variant_name = &v.ident;
62
63        let tag = v
64            .attrs
65            .iter()
66            .find(|attr| {
67                attr.meta
68                    .path()
69                    .get_ident()
70                    .map_or(false, |i| i == "tagged_serde")
71            })
72            .map(|attr| {
73                let nv = attr.meta.require_name_value().expect("name-value");
74                &nv.value
75            })
76            .expect("No enum tag found for {variant_name}");
77
78        let number_of_fields = match &v.fields {
79            Fields::Unnamed(FieldsUnnamed {
80                paren_token: _,
81                unnamed,
82            }) => Some(unnamed.len()),
83            Fields::Unit => None,
84            _ => unimplemented!(),
85        };
86
87        let variant_pattern = if let Some(number_of_fields) = number_of_fields {
88            let variant_args: Vec<_> = (0..number_of_fields)
89                .map(|_| {
90                    quote! {
91                        seq
92                            .next_element().map_err(|e| A::Error::custom(format!("failed to read variant with tag {}: {}", tag, e)))?
93                            .ok_or_else(|| A::Error::custom(format!("failed to read variant with tag {}", tag)))?
94                    }
95                })
96                .collect();
97            quote! {
98                (#( #variant_args ),*)
99            }
100        } else {
101            quote! {
102            }
103        };
104
105        quote! {
106            #tag => {
107                Ok(#ident::#variant_name #variant_pattern)
108            }
109        }
110    });
111
112    // FIXME don't hardcode u64 in the deserializer tag
113    let output = quote! {
114        impl ::serde::Serialize for #ident {
115            fn serialize<S>(&self, serializer: S) -> ::core::result::Result<S::Ok, S::Error>
116            where
117                S: ::serde::Serializer,
118            {
119                match self {
120                    #( #variants ),*
121                }
122            }
123        }
124
125        impl<'de> Deserialize<'de> for #ident {
126            fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
127            where
128                D: serde::Deserializer<'de>,
129            {
130                use serde::de::Error;
131                struct Visitor;
132
133                impl<'d> serde::de::Visitor<'d> for Visitor {
134                    type Value = #ident;
135
136                    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
137                        formatter.write_str("either a string or an int")
138                    }
139
140                    fn visit_seq<A: serde::de::SeqAccess<'d>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
141                        let tag: u64 = seq
142                            .next_element()?
143                            .ok_or_else(|| A::Error::custom("failed to read logger field tag"))?;
144                        match tag {
145                            #( #deser_variants ),*
146                            _ => Err(A::Error::custom(format!("unknown tag {} when deserializing {}", tag, stringify!(#ident)))),
147                        }
148                    }
149                }
150
151                // TODO: make it a tuple with 2 fields: (tag, rest)
152                // We don't know yet how many fields to expect. We're abusing
153                // the fact that the nix serde implementation doesn't actually
154                // look at the size of the tuple.
155                deserializer.deserialize_tuple(usize::MAX, Visitor)
156            }
157        }
158    };
159    output.into()
160}