serde_deserialize_over_derive/
lib.rs

1//! Derive macros for serde-deserialize-over.
2
3#![recursion_limit = "256"]
4
5extern crate proc_macro;
6
7mod attr;
8
9use std::collections::HashSet;
10
11// use proc_macro::TokenStream;
12use proc_macro2::{Span, TokenStream};
13use proc_macro_crate::{crate_name, FoundCrate};
14use quote::{quote, ToTokens};
15use syn::{
16  parse_macro_input, parse_quote, punctuated::Punctuated, Attribute, Data, DeriveInput, Fields,
17  FieldsNamed, FieldsUnnamed, GenericParam, Ident, Path, Token, Type,
18};
19
/// Canonical (underscored) name of the runtime crate. `derive` falls back to it when the
/// dependency lookup fails or resolves to the crate itself, and the generated code's local
/// alias (`_serde_deserialize_over`) is built from it.
const CRATE_NAME: &str = "serde_deserialize_over";
21
22/// Derive macro for the `DeserializeOver` trait.
23#[proc_macro_derive(DeserializeOver, attributes(deserialize_over, serde))]
24pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
25  let input = parse_macro_input!(input as DeriveInput);
26  let crate_name =
27    crate_name("serde-deserialize-over").unwrap_or(FoundCrate::Name(CRATE_NAME.to_string()));
28  let crate_name = match crate_name {
29    FoundCrate::Name(name) => Ident::new(&name, Span::call_site()),
30    FoundCrate::Itself => Ident::new(CRATE_NAME, Span::call_site()),
31  };
32
33  let data = match input.data {
34    Data::Struct(ref data) => data.clone(),
35    Data::Enum(_) => panic!("`DeserializeOver` cannot be automatically derived for enums"),
36    Data::Union(_) => panic!("`DeserializeOver` cannot be automatically derived for unions"),
37  };
38
39  let res = match data.fields {
40    Fields::Named(fields) => impl_named_fields(input, crate_name, fields),
41    Fields::Unnamed(fields) => impl_unnamed_fields(input, crate_name, fields),
42    Fields::Unit => impl_unit(input, crate_name),
43  };
44
45  match res {
46    Ok(res) => {
47      // panic!("{}", res);
48      res.into()
49    }
50    Err(e) => e.to_compile_error().into(),
51  }
52}
53
54#[derive(Clone)]
55struct FieldInfo {
56  name: Ident,
57  ty: Type,
58  passthrough: bool,
59  deserialize_with: Option<Path>,
60  deserialize_merge_with: Option<Path>,
61
62  srcname: Option<String>,
63  enum_value: Ident,
64}
65
66impl FieldInfo {
67  fn build_de_wrapper(&self, export: &syn::Path) -> TokenStream {
68    let Self { name, ty, .. } = self;
69    let visname = Ident::new(&format!("FieldWrapper{}", self.enum_value), name.span());
70    let lt = syn::Lifetime::new("'_serde_deserialize_over_a", Span::call_site());
71
72    if self.passthrough {
73      if let Some(merge_fn) = &self.deserialize_merge_with {
74        quote::quote! {{
75          struct #visname<#lt>(&#lt mut #ty);
76
77          impl<'de> #export::DeserializeSeed<'de> for #visname<'_> {
78            type Value = ();
79
80            fn deserialize<D>(self, deserializer: D) -> #export::Result<Self::Value, D::Error>
81            where
82                D: #export::Deserializer<'de>
83            {
84              #merge_fn(deserializer, self.0)
85            }
86          }
87
88          #visname(&mut (self.0).#name)
89        }}
90      } else {
91        if self.deserialize_with.is_some() {
92          return quote::quote_spanned! {
93            name.span() => {
94              compile_error!(r#"Field uses both $[serde(deserialize_with)] and #[deserializer_over]. Use #[serde(with = "...")] so that the DeserializeOver derive will use a custom deserialize function."#);
95              unreachable!()
96            }
97          };
98        }
99
100        quote! { #export::DeserializeOverWrapper(&mut (self.0).#name) }
101      }
102    } else {
103      if let Some(de_fn) = &self.deserialize_with {
104        quote::quote! {{
105          struct #visname<#lt>(&#lt mut #ty);
106
107          impl<'de> #export::DeserializeSeed<'de> for #visname<'_> {
108            type Value = ();
109
110            fn deserialize<D>(self, deserializer: D) -> #export::Result<Self::Value, D::Error>
111            where
112                D: #export::Deserializer<'de>
113            {
114              *self.0 = #de_fn(deserializer)?;
115              Ok(())
116            }
117          }
118
119          #visname(&mut (self.0).#name)
120        }}
121      } else {
122        quote! { #export::DeserializeWrapper(&mut (self.0).#name) }
123      }
124    }
125  }
126
127  fn map_de(&self, export: &syn::Path) -> TokenStream {
128    let wrapper = self.build_de_wrapper(export);
129    quote! { map.next_value_seed(#wrapper)? }
130  }
131
132  fn seq_de(&self, export: &syn::Path) -> TokenStream {
133    let wrapper = self.build_de_wrapper(export);
134    quote! {
135      if seq.next_element_seed(#wrapper)?.is_none() {
136        return Ok(())
137      }
138    }
139  }
140
141  fn source_name(&self) -> syn::LitStr {
142    match &self.srcname {
143      Some(name) => syn::LitStr::new(&name, self.name.span()),
144      None => syn::LitStr::new(&self.name.to_string(), self.name.span()),
145    }
146  }
147}
148
149fn impl_generic(
150  mut input: DeriveInput,
151  real_crate_name: Ident,
152  fields: Vec<FieldInfo>,
153  fields_numbered: bool,
154) -> syn::Result<TokenStream> {
155  let struct_name = &input.ident;
156  let deserializer = Ident::new("__deserializer", Span::call_site());
157  let crate_name = Ident::new(&("_".to_owned() + CRATE_NAME), Span::call_site());
158  let export = syn::parse_quote! { #crate_name::export };
159
160  let field_enums = fields
161    .iter()
162    .map(|field| &field.enum_value)
163    .cloned()
164    .collect::<Vec<_>>();
165  let field_enums = &field_enums;
166  let field_enums_copy1 = field_enums;
167  let field_enums_copy2 = field_enums;
168  let field_names = fields.iter().map(|x| x.source_name()).collect::<Vec<_>>();
169  let field_names = &field_names;
170  let indices = (0usize..fields.len()).collect::<Vec<_>>();
171  let indices_u64 = indices.iter().map(|x| *x as u64);
172
173  let missing_field_error_str = syn::LitStr::new(
174    &format!("field index between 0 <= i < {}", fields.len()),
175    Span::call_site(),
176  );
177
178  let visit_str_and_bytes_impl = if !fields_numbered {
179    let names_str = &field_names;
180    let names_bytes = field_names
181      .iter()
182      .map(|x| syn::LitByteStr::new(x.value().as_bytes(), x.span()))
183      .collect::<Vec<_>>();
184
185    quote! {
186      fn visit_str<E>(self, value: &str) -> #export::Result<Self::Value, E>
187      where
188        E: #export::Error
189      {
190        #export::Ok(match value {
191          #( #names_str => __Field::#field_enums, )*
192          _ => __Field::__ignore
193        })
194      }
195
196      fn visit_bytes<E>(self, value: &[u8]) -> #export::Result<Self::Value, E>
197      where
198        E: #export::Error
199      {
200        #export::Ok(match value {
201          #( #names_bytes => __Field::#field_enums, )*
202          _ => __Field::__ignore
203        })
204      }
205    }
206  } else {
207    quote! {}
208  };
209
210  let map_de_entries = fields
211    .iter()
212    .map(|field| field.map_de(&export))
213    .collect::<Vec<_>>();
214
215  let visit_seq_entries = fields
216    .iter()
217    .map(|field| field.seq_de(&export))
218    .collect::<Vec<_>>();
219
220  if !input.generics.params.is_empty() {
221    let where_clause = input.generics.make_where_clause();
222
223    for field in fields.iter() {
224      let ty = &field.ty;
225
226      if field.passthrough {
227        where_clause.predicates.push(parse_quote! {
228          #ty: #crate_name::DeserializeOver<'de>
229        });
230      } else {
231        where_clause.predicates.push(parse_quote! {
232          #ty: #crate_name::export::Deserialize<'de>
233        });
234      }
235    }
236  }
237
238  let (_, ty_generics, where_clause) = input.generics.split_for_impl();
239  let impl_generics = &input.generics.params;
240
241  let visitor_params = impl_generics
242    .iter()
243    .map(|param| match param {
244      GenericParam::Type(ty) => ty.ident.to_token_stream(),
245      GenericParam::Lifetime(lt) => lt.lifetime.to_token_stream(),
246      GenericParam::Const(cnst) => cnst.ident.to_token_stream(),
247    })
248    .collect::<Punctuated<_, Token![,]>>();
249
250  let inner = quote! {
251    #[allow(unknown_lints)]
252    #[allow(rust_2018_idioms)]
253    extern crate #real_crate_name as #crate_name;
254
255    #[automatically_derived]
256    impl<'de, #impl_generics> #crate_name::DeserializeOver<'de> for #struct_name #ty_generics
257      #where_clause
258    {
259      fn deserialize_over<D>(&mut self, #deserializer: D) -> #export::Result<(), D::Error>
260      where
261        D: #export::Deserializer<'de>
262      {
263        #[allow(non_camel_case_types)]
264        enum __Field {
265          #( #field_enums, )*
266          __ignore
267        }
268        impl<'de> #export::Deserialize<'de> for __Field {
269          fn deserialize<D>(#deserializer: D) -> #export::Result<Self, D::Error>
270          where
271            D: #export::Deserializer<'de>
272          {
273            #export::Deserializer::deserialize_identifier(#deserializer, __FieldVisitor)
274          }
275        }
276
277        struct __FieldVisitor;
278        impl<'de> #export::Visitor<'de> for __FieldVisitor {
279          type Value = __Field;
280
281          fn expecting(&self, fmt: &mut #export::fmt::Formatter) -> #export::fmt::Result {
282            #export::fmt::Formatter::write_str(fmt, "field identifier")
283          }
284
285          fn visit_u64<E>(self, value: u64) -> #export::Result<Self::Value, E>
286          where
287            E: #export::Error
288          {
289            use #export::{Ok, Err};
290
291            Ok(match value {
292              #( #indices_u64 => __Field::#field_enums, )*
293              _ => return Err(#export::Error::invalid_value(
294                #export::Unexpected::Unsigned(value),
295                &#missing_field_error_str
296              ))
297            })
298          }
299
300          #visit_str_and_bytes_impl
301        }
302
303        struct __Visitor<'a, #impl_generics>(pub &'a mut #struct_name #ty_generics);
304
305        impl<'a, 'de, #impl_generics> #export::Visitor<'de> for __Visitor<'a, #visitor_params>
306          #where_clause
307        {
308          type Value = ();
309
310          fn expecting(&self, fmt: &mut #export::fmt::Formatter) -> #export::fmt::Result {
311            #export::fmt::Formatter::write_str(fmt, concat!("struct ", stringify!(#struct_name)))
312          }
313
314          fn visit_seq<A>(self, mut seq: A) -> #export::Result<Self::Value, A::Error>
315          where
316            A: #export::SeqAccess<'de>
317          {
318            use #export::{Some, None};
319
320            #( #visit_seq_entries; )*
321
322            Ok(())
323          }
324
325          fn visit_map<A>(self, mut map: A) -> #export::Result<Self::Value, A::Error>
326          where
327            A: #export::MapAccess<'de>
328          {
329            use #export::{Some, None, Error};
330
331            // State tracking
332            #(
333              let mut #field_enums: bool = false;
334            )*
335
336            while let Some(key) = map.next_key::<__Field>()? {
337              match key {
338                #(
339                  __Field::#field_enums => if #field_enums_copy1 {
340                    return Err(<A::Error as Error>::duplicate_field(stringify!(#field_names)));
341                  } else {
342                    #field_enums_copy2 = true;
343                    #map_de_entries;
344                  }
345                )*
346                _ => (),
347              }
348            }
349
350            Ok(())
351          }
352        }
353
354        const FIELDS: &[&str] = &[
355          #( stringify!(#field_names), )*
356        ];
357
358        #export::Deserializer::deserialize_struct(
359          #deserializer,
360          stringify!(#struct_name),
361          FIELDS,
362          __Visitor(self)
363        )
364      }
365    }
366  };
367
368  let const_name = Ident::new(
369    &format!("_IMPL_DESERIALIZE_OVER_FOR_{}", struct_name),
370    struct_name.span(),
371  );
372
373  Ok(
374    quote! {
375      #[allow(non_upper_case_globals, unused_attributes, unused_qualifications, non_camel_case_types)]
376      const #const_name: () = {
377        #inner
378      };
379    }
380    .into(),
381  )
382}
383
384fn impl_named_fields(
385  input: DeriveInput,
386  crate_name: Ident,
387  fields: FieldsNamed,
388) -> syn::Result<TokenStream> {
389  let fieldinfos = fields
390    .named
391    .iter()
392    .enumerate()
393    .map(|(idx, x)| {
394      let attr = parse_attr(x.attrs.iter())?;
395
396      let name = x.ident.clone().unwrap();
397
398      Ok(FieldInfo {
399        enum_value: Ident::new(&format!("__field{}", idx), name.span()),
400
401        name,
402        ty: x.ty.clone(),
403        passthrough: attr.use_deserialize_over,
404        deserialize_with: attr.deserialize_fn,
405        deserialize_merge_with: attr.deserialize_merge_fn,
406        srcname: attr.rename.map(|x| x.value()),
407      })
408    })
409    .collect::<Result<Vec<_>, syn::Error>>()?;
410
411  return impl_generic(input, crate_name, fieldinfos, false);
412}
413
414fn impl_unnamed_fields(
415  _input: DeriveInput,
416  _crate_name: Ident,
417  _fields: FieldsUnnamed,
418) -> syn::Result<TokenStream> {
419  panic!("Deriving DeserializeInto for tuple structs is not supported");
420}
421
422fn impl_unit(input: DeriveInput, crate_name: Ident) -> syn::Result<TokenStream> {
423  let struct_name = &input.ident;
424
425  Ok(
426    quote! {
427      impl ::#crate_name::DeserializeOver for #struct_name {
428        fn deserialize_over<'de, D>(&mut self, de: D) -> Result<(), D::Error>
429        where
430          D: Deserializer<'de>
431        {
432          Ok(())
433        }
434      }
435    }
436    .into(),
437  )
438}
439
440#[derive(Default)]
441struct ParsedAttr {
442  use_deserialize_over: bool,
443  deserialize_fn: Option<Path>,
444  deserialize_merge_fn: Option<Path>,
445  rename: Option<syn::LitStr>,
446}
447
448fn parse_attr<'a, I>(attrs: I) -> syn::Result<ParsedAttr>
449where
450  I: Iterator<Item = &'a Attribute>,
451{
452  use syn::spanned::Spanned;
453
454  let mut result = ParsedAttr::default();
455
456  for attr in attrs.into_iter() {
457    if attr.path.is_ident("deserialize_over") {
458      if !attr.tokens.is_empty() {
459        return Err(syn::Error::new_spanned(
460          attr.path.to_token_stream(),
461          "deserialize_over attribute should not have any arguments",
462        ));
463      }
464
465      result.use_deserialize_over = true;
466    } else if attr.path.is_ident("serde") {
467      let body: self::attr::SerdeAttrBody = syn::parse2(attr.tokens.clone())?;
468      let mut seen = HashSet::new();
469
470      for opt in body.attrs.iter() {
471        let ident = opt.ident().to_string();
472
473        // Put serde arguments that we support here so that we can error out on
474        // unsupported ones.
475        match &*ident {
476          "with" | "deserialize_with" | "serialize_with" => (),
477          "rename" | "serialize" | "deserialize" => (),
478          // #[serde(default)] is ignored since we already have values for all fields
479          "default" => (),
480          name => {
481            return Err(syn::Error::new(
482              opt.span(),
483              &format!(
484                r#"#[serde({}{}) is not supported by the DeserializeOver derive macro."#,
485                name,
486                if opt.is_flag() { r#" = "...""# } else { "" }
487              ),
488            ))
489          }
490        }
491
492        if !seen.insert(ident) {
493          return Err(syn::Error::new_spanned(
494            opt,
495            &format!(
496              "Option `{}` cannot be specified multiple times",
497              opt.ident()
498            ),
499          ));
500        }
501      }
502
503      if let Some(lit) = body.get("with") {
504        result.deserialize_fn = Some(
505          syn::parse_str(&(lit.value() + "::deserialize"))
506            .map_err(|e| syn::Error::new_spanned(lit, e))?,
507        );
508        result.deserialize_merge_fn = Some(
509          syn::parse_str(&(lit.value() + "::deserialize_over"))
510            .map_err(|e| syn::Error::new_spanned(lit, e))?,
511        );
512      }
513
514      if let Some(lit) = body.get("deserialize_with") {
515        if result.deserialize_fn.is_some() {
516          return Err(syn::Error::new(
517            body.span_for("deserialize_with"),
518            "Cannot specify both `with` and `deserialize_with`",
519          ));
520        }
521
522        result.deserialize_fn =
523          Some(syn::parse_str(&lit.value()).map_err(|e| syn::Error::new_spanned(lit, e))?);
524      }
525
526      if let Some(lit) = body.get("rename") {
527        result.rename = Some(lit.clone());
528      }
529
530      if let Some(lit) = body.get("deserialize") {
531        if result.rename.is_some() {
532          return Err(syn::Error::new(
533            body.span_for("deserialize"),
534            "Cannot specify both `rename` and `deserialize`",
535          ));
536        }
537
538        result.rename = Some(lit.clone());
539      }
540    }
541  }
542
543  Ok(result)
544}