serde_struct_tuple_proc_macro/
lib.rs

1#![no_std]
2
3extern crate alloc;
4extern crate proc_macro;
5
6use alloc::{
7    fmt::format,
8    vec::Vec,
9};
10
11use proc_macro::TokenStream;
12use proc_macro2::Span;
13use quote::quote;
14use syn::{
15    Error,
16    Field,
17    Ident,
18    ItemStruct,
19    Meta,
20    Path,
21    Type,
22    parse::{
23        Parse,
24        ParseStream,
25    },
26    parse_macro_input,
27};
28
29#[derive(Default)]
30enum DefaultAttr {
31    #[default]
32    False,
33    True,
34    Path(Path),
35}
36
37impl DefaultAttr {
38    pub fn can_be_default(&self) -> bool {
39        match self {
40            Self::False => false,
41            _ => true,
42        }
43    }
44}
45
46#[derive(Default)]
47struct InputFieldAttrs {
48    default: DefaultAttr,
49    skip_serializing_if: Option<Path>,
50}
51
52struct InputField {
53    ident: Option<Ident>,
54    ty: Type,
55    attrs: InputFieldAttrs,
56}
57
58struct Input {
59    ident: Ident,
60    fields: Vec<InputField>,
61}
62
63fn parse_input_field_attrs(field: &Field) -> syn::Result<InputFieldAttrs> {
64    let serde_attr = field.attrs.iter().find(|attr| {
65        if let Meta::List(list) = &attr.meta {
66            if list.path.is_ident("serde_struct_tuple") {
67                return true;
68            }
69        }
70        false
71    });
72    let serde_attr = match serde_attr {
73        Some(attr) => attr,
74        None => return Ok(InputFieldAttrs::default()),
75    };
76
77    let mut default = DefaultAttr::False;
78    let mut skip_serializing_if = None;
79    serde_attr.parse_nested_meta(|meta| {
80        if meta.path.is_ident("default") {
81            default = match meta.value() {
82                Ok(value) => DefaultAttr::Path(value.parse::<Path>()?),
83                Err(_) => DefaultAttr::True,
84            }
85        }
86        if meta.path.is_ident("skip_serializing_if") {
87            let value = meta.value()?;
88            skip_serializing_if = Some(value.parse::<Path>()?);
89        }
90        Ok(())
91    })?;
92    Ok(InputFieldAttrs {
93        default,
94        skip_serializing_if,
95    })
96}
97
98impl Parse for Input {
99    fn parse(input: ParseStream) -> syn::Result<Self> {
100        let call_site = Span::call_site();
101        let input = match ItemStruct::parse(input) {
102            Ok(item) => item,
103            Err(_) => return Err(Error::new(call_site, "input must be a struct")),
104        };
105        let ident = input.ident;
106        let mut fields = Vec::new();
107        let mut defaulted = false;
108        for field in input.fields {
109            let attrs = parse_input_field_attrs(&field)?;
110            if attrs.default.can_be_default() {
111                defaulted = true
112            } else if defaulted {
113                return Err(Error::new(
114                    call_site,
115                    "fields after a default field must also be default",
116                ));
117            }
118            fields.push(InputField {
119                ident: field.ident,
120                ty: field.ty,
121                attrs,
122            });
123        }
124        Ok(Input { ident, fields })
125    }
126}
127
128/// Implements `serde_struct_tuple::DeserializeStructTuple` and `serde::Deserialize` for the
129/// struct.
130#[proc_macro_derive(DeserializeStructTuple, attributes(serde_struct_tuple))]
131pub fn derive_deserialize_struct_tuple(input: TokenStream) -> TokenStream {
132    let input = parse_macro_input!(input as Input);
133    let call_site = Span::call_site();
134
135    let ident = input.ident;
136    let visitor_ident = Ident::new(&format(format_args!("{ident}Visitor")), call_site);
137    let field_deserializers = input
138        .fields
139        .iter()
140        .map(|field| {
141            let ident = field.ident.as_ref().unwrap();
142            let field_ty = &field.ty;
143            let if_empty = match &field.attrs.default {
144                DefaultAttr::False => {
145                    quote!(return Err(serde::de::Error::missing_field(stringify!(#ident))))
146                }
147                DefaultAttr::True => quote!(<#field_ty as Default>::default()),
148                DefaultAttr::Path(path) => quote!(#path()),
149            };
150            quote!(#ident: match value.next_element()? {
151                Some(value) => value,
152                None => #if_empty,
153            })
154        })
155        .collect::<Vec<_>>();
156
157    quote! {
158        impl serde_struct_tuple::DeserializeStructTuple for #ident {
159            type Value = #ident;
160            fn visitor<'de>() -> impl serde::de::Visitor<'de, Value = Self::Value> {
161                struct #visitor_ident;
162                impl<'de> serde::de::Visitor<'de> for #visitor_ident {
163                    type Value = #ident;
164
165                    fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
166                        formatter.write_fmt(format_args!("{} tuple", stringify!(#ident)))
167                    }
168
169                    fn visit_seq<A>(self, mut value: A) -> Result<Self::Value, A::Error>
170                    where
171                        A: serde::de::SeqAccess<'de>,
172                    {
173                        Ok(#ident {
174                            #(#field_deserializers,)*
175                        })
176                    }
177                }
178
179                #visitor_ident
180            }
181        }
182
183
184        impl<'de> serde::Deserialize<'de> for #ident {
185            fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error> where D: serde::Deserializer<'de> {
186                deserializer.deserialize_seq(Self::visitor())
187            }
188        }
189    }
190    .into()
191}
192
193/// Implements `serde_struct_tuple::SerializeStructTuple` and `serde::Serialize` for the struct.
194#[proc_macro_derive(SerializeStructTuple, attributes(serde_struct_tuple))]
195pub fn derive_serialize_struct_tuple(input: TokenStream) -> TokenStream {
196    let input = parse_macro_input!(input as Input);
197
198    let ident = input.ident;
199
200    let run_skipped_serializers = quote! {
201        for skipped in skipped {
202            skipped(seq)?;
203        }
204        skipped = std::vec::Vec::default();
205    };
206
207    let field_serializers = input
208        .fields
209        .iter()
210        .map(|field| {
211            let ident = field.ident.as_ref().unwrap();
212            match &field.attrs.skip_serializing_if {
213                Some(skip_serializing_if) => quote! {
214                    if #skip_serializing_if(&self.#ident) {
215                        skipped.push(std::boxed::Box::new(|seq| seq.serialize_element(&self.#ident)));
216                    } else {
217                        #run_skipped_serializers
218                        seq.serialize_element(&self.#ident)?;
219                    }
220                },
221                None => quote! {
222                    #run_skipped_serializers
223                        seq.serialize_element(&self.#ident)?;
224                },
225            }
226        })
227        .collect::<Vec<_>>();
228
229    quote! {
230        impl serde_struct_tuple::SerializeStructTuple for #ident {
231            fn serialize_fields_to_seq<S>(&self, seq: &mut S) -> core::result::Result<(), S::Error> where S: serde::ser::SerializeSeq {
232                use serde::ser::SerializeSeq;
233                let mut skipped: std::vec::Vec<std::boxed::Box<dyn core::ops::FnOnce(&mut S) -> core::result::Result<(), S::Error>>> = std::vec::Vec::default();
234                #(#field_serializers)*
235                Ok(())
236            }
237        }
238
239        impl serde::Serialize for #ident {
240            fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error> where S: serde::Serializer {
241                use serde::ser::SerializeSeq;
242                let mut seq = serializer.serialize_seq(None)?;
243                self.serialize_fields_to_seq(&mut seq)?;
244                seq.end()
245            }
246        }
247    }
248    .into()
249}