1#![recursion_limit = "256"]
4
5extern crate proc_macro;
6
7mod attr;
8
9use std::collections::HashSet;
10
11use proc_macro2::{Span, TokenStream};
13use proc_macro_crate::{crate_name, FoundCrate};
14use quote::{quote, ToTokens};
15use syn::{
16 parse_macro_input, parse_quote, punctuated::Punctuated, Attribute, Data, DeriveInput, Fields,
17 FieldsNamed, FieldsUnnamed, GenericParam, Ident, Path, Token, Type,
18};
19
/// Canonical (underscore) name of the runtime crate.
///
/// `derive` uses it as the fallback when `proc_macro_crate` cannot resolve the dependency, or
/// reports that the macro is expanding inside that crate itself. `impl_generic` also prefixes it
/// with `_` to build the private `extern crate` alias used inside the generated code.
const CRATE_NAME: &str = "serde_deserialize_over";
21
22#[proc_macro_derive(DeserializeOver, attributes(deserialize_over, serde))]
24pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
25 let input = parse_macro_input!(input as DeriveInput);
26 let crate_name =
27 crate_name("serde-deserialize-over").unwrap_or(FoundCrate::Name(CRATE_NAME.to_string()));
28 let crate_name = match crate_name {
29 FoundCrate::Name(name) => Ident::new(&name, Span::call_site()),
30 FoundCrate::Itself => Ident::new(CRATE_NAME, Span::call_site()),
31 };
32
33 let data = match input.data {
34 Data::Struct(ref data) => data.clone(),
35 Data::Enum(_) => panic!("`DeserializeOver` cannot be automatically derived for enums"),
36 Data::Union(_) => panic!("`DeserializeOver` cannot be automatically derived for unions"),
37 };
38
39 let res = match data.fields {
40 Fields::Named(fields) => impl_named_fields(input, crate_name, fields),
41 Fields::Unnamed(fields) => impl_unnamed_fields(input, crate_name, fields),
42 Fields::Unit => impl_unit(input, crate_name),
43 };
44
45 match res {
46 Ok(res) => {
47 res.into()
49 }
50 Err(e) => e.to_compile_error().into(),
51 }
52}
53
54#[derive(Clone)]
55struct FieldInfo {
56 name: Ident,
57 ty: Type,
58 passthrough: bool,
59 deserialize_with: Option<Path>,
60 deserialize_merge_with: Option<Path>,
61
62 srcname: Option<String>,
63 enum_value: Ident,
64}
65
66impl FieldInfo {
67 fn build_de_wrapper(&self, export: &syn::Path) -> TokenStream {
68 let Self { name, ty, .. } = self;
69 let visname = Ident::new(&format!("FieldWrapper{}", self.enum_value), name.span());
70 let lt = syn::Lifetime::new("'_serde_deserialize_over_a", Span::call_site());
71
72 if self.passthrough {
73 if let Some(merge_fn) = &self.deserialize_merge_with {
74 quote::quote! {{
75 struct #visname<#lt>(&#lt mut #ty);
76
77 impl<'de> #export::DeserializeSeed<'de> for #visname<'_> {
78 type Value = ();
79
80 fn deserialize<D>(self, deserializer: D) -> #export::Result<Self::Value, D::Error>
81 where
82 D: #export::Deserializer<'de>
83 {
84 #merge_fn(deserializer, self.0)
85 }
86 }
87
88 #visname(&mut (self.0).#name)
89 }}
90 } else {
91 if self.deserialize_with.is_some() {
92 return quote::quote_spanned! {
93 name.span() => {
94 compile_error!(r#"Field uses both $[serde(deserialize_with)] and #[deserializer_over]. Use #[serde(with = "...")] so that the DeserializeOver derive will use a custom deserialize function."#);
95 unreachable!()
96 }
97 };
98 }
99
100 quote! { #export::DeserializeOverWrapper(&mut (self.0).#name) }
101 }
102 } else {
103 if let Some(de_fn) = &self.deserialize_with {
104 quote::quote! {{
105 struct #visname<#lt>(&#lt mut #ty);
106
107 impl<'de> #export::DeserializeSeed<'de> for #visname<'_> {
108 type Value = ();
109
110 fn deserialize<D>(self, deserializer: D) -> #export::Result<Self::Value, D::Error>
111 where
112 D: #export::Deserializer<'de>
113 {
114 *self.0 = #de_fn(deserializer)?;
115 Ok(())
116 }
117 }
118
119 #visname(&mut (self.0).#name)
120 }}
121 } else {
122 quote! { #export::DeserializeWrapper(&mut (self.0).#name) }
123 }
124 }
125 }
126
127 fn map_de(&self, export: &syn::Path) -> TokenStream {
128 let wrapper = self.build_de_wrapper(export);
129 quote! { map.next_value_seed(#wrapper)? }
130 }
131
132 fn seq_de(&self, export: &syn::Path) -> TokenStream {
133 let wrapper = self.build_de_wrapper(export);
134 quote! {
135 if seq.next_element_seed(#wrapper)?.is_none() {
136 return Ok(())
137 }
138 }
139 }
140
141 fn source_name(&self) -> syn::LitStr {
142 match &self.srcname {
143 Some(name) => syn::LitStr::new(&name, self.name.span()),
144 None => syn::LitStr::new(&self.name.to_string(), self.name.span()),
145 }
146 }
147}
148
149fn impl_generic(
150 mut input: DeriveInput,
151 real_crate_name: Ident,
152 fields: Vec<FieldInfo>,
153 fields_numbered: bool,
154) -> syn::Result<TokenStream> {
155 let struct_name = &input.ident;
156 let deserializer = Ident::new("__deserializer", Span::call_site());
157 let crate_name = Ident::new(&("_".to_owned() + CRATE_NAME), Span::call_site());
158 let export = syn::parse_quote! { #crate_name::export };
159
160 let field_enums = fields
161 .iter()
162 .map(|field| &field.enum_value)
163 .cloned()
164 .collect::<Vec<_>>();
165 let field_enums = &field_enums;
166 let field_enums_copy1 = field_enums;
167 let field_enums_copy2 = field_enums;
168 let field_names = fields.iter().map(|x| x.source_name()).collect::<Vec<_>>();
169 let field_names = &field_names;
170 let indices = (0usize..fields.len()).collect::<Vec<_>>();
171 let indices_u64 = indices.iter().map(|x| *x as u64);
172
173 let missing_field_error_str = syn::LitStr::new(
174 &format!("field index between 0 <= i < {}", fields.len()),
175 Span::call_site(),
176 );
177
178 let visit_str_and_bytes_impl = if !fields_numbered {
179 let names_str = &field_names;
180 let names_bytes = field_names
181 .iter()
182 .map(|x| syn::LitByteStr::new(x.value().as_bytes(), x.span()))
183 .collect::<Vec<_>>();
184
185 quote! {
186 fn visit_str<E>(self, value: &str) -> #export::Result<Self::Value, E>
187 where
188 E: #export::Error
189 {
190 #export::Ok(match value {
191 #( #names_str => __Field::#field_enums, )*
192 _ => __Field::__ignore
193 })
194 }
195
196 fn visit_bytes<E>(self, value: &[u8]) -> #export::Result<Self::Value, E>
197 where
198 E: #export::Error
199 {
200 #export::Ok(match value {
201 #( #names_bytes => __Field::#field_enums, )*
202 _ => __Field::__ignore
203 })
204 }
205 }
206 } else {
207 quote! {}
208 };
209
210 let map_de_entries = fields
211 .iter()
212 .map(|field| field.map_de(&export))
213 .collect::<Vec<_>>();
214
215 let visit_seq_entries = fields
216 .iter()
217 .map(|field| field.seq_de(&export))
218 .collect::<Vec<_>>();
219
220 if !input.generics.params.is_empty() {
221 let where_clause = input.generics.make_where_clause();
222
223 for field in fields.iter() {
224 let ty = &field.ty;
225
226 if field.passthrough {
227 where_clause.predicates.push(parse_quote! {
228 #ty: #crate_name::DeserializeOver<'de>
229 });
230 } else {
231 where_clause.predicates.push(parse_quote! {
232 #ty: #crate_name::export::Deserialize<'de>
233 });
234 }
235 }
236 }
237
238 let (_, ty_generics, where_clause) = input.generics.split_for_impl();
239 let impl_generics = &input.generics.params;
240
241 let visitor_params = impl_generics
242 .iter()
243 .map(|param| match param {
244 GenericParam::Type(ty) => ty.ident.to_token_stream(),
245 GenericParam::Lifetime(lt) => lt.lifetime.to_token_stream(),
246 GenericParam::Const(cnst) => cnst.ident.to_token_stream(),
247 })
248 .collect::<Punctuated<_, Token![,]>>();
249
250 let inner = quote! {
251 #[allow(unknown_lints)]
252 #[allow(rust_2018_idioms)]
253 extern crate #real_crate_name as #crate_name;
254
255 #[automatically_derived]
256 impl<'de, #impl_generics> #crate_name::DeserializeOver<'de> for #struct_name #ty_generics
257 #where_clause
258 {
259 fn deserialize_over<D>(&mut self, #deserializer: D) -> #export::Result<(), D::Error>
260 where
261 D: #export::Deserializer<'de>
262 {
263 #[allow(non_camel_case_types)]
264 enum __Field {
265 #( #field_enums, )*
266 __ignore
267 }
268 impl<'de> #export::Deserialize<'de> for __Field {
269 fn deserialize<D>(#deserializer: D) -> #export::Result<Self, D::Error>
270 where
271 D: #export::Deserializer<'de>
272 {
273 #export::Deserializer::deserialize_identifier(#deserializer, __FieldVisitor)
274 }
275 }
276
277 struct __FieldVisitor;
278 impl<'de> #export::Visitor<'de> for __FieldVisitor {
279 type Value = __Field;
280
281 fn expecting(&self, fmt: &mut #export::fmt::Formatter) -> #export::fmt::Result {
282 #export::fmt::Formatter::write_str(fmt, "field identifier")
283 }
284
285 fn visit_u64<E>(self, value: u64) -> #export::Result<Self::Value, E>
286 where
287 E: #export::Error
288 {
289 use #export::{Ok, Err};
290
291 Ok(match value {
292 #( #indices_u64 => __Field::#field_enums, )*
293 _ => return Err(#export::Error::invalid_value(
294 #export::Unexpected::Unsigned(value),
295 &#missing_field_error_str
296 ))
297 })
298 }
299
300 #visit_str_and_bytes_impl
301 }
302
303 struct __Visitor<'a, #impl_generics>(pub &'a mut #struct_name #ty_generics);
304
305 impl<'a, 'de, #impl_generics> #export::Visitor<'de> for __Visitor<'a, #visitor_params>
306 #where_clause
307 {
308 type Value = ();
309
310 fn expecting(&self, fmt: &mut #export::fmt::Formatter) -> #export::fmt::Result {
311 #export::fmt::Formatter::write_str(fmt, concat!("struct ", stringify!(#struct_name)))
312 }
313
314 fn visit_seq<A>(self, mut seq: A) -> #export::Result<Self::Value, A::Error>
315 where
316 A: #export::SeqAccess<'de>
317 {
318 use #export::{Some, None};
319
320 #( #visit_seq_entries; )*
321
322 Ok(())
323 }
324
325 fn visit_map<A>(self, mut map: A) -> #export::Result<Self::Value, A::Error>
326 where
327 A: #export::MapAccess<'de>
328 {
329 use #export::{Some, None, Error};
330
331 #(
333 let mut #field_enums: bool = false;
334 )*
335
336 while let Some(key) = map.next_key::<__Field>()? {
337 match key {
338 #(
339 __Field::#field_enums => if #field_enums_copy1 {
340 return Err(<A::Error as Error>::duplicate_field(stringify!(#field_names)));
341 } else {
342 #field_enums_copy2 = true;
343 #map_de_entries;
344 }
345 )*
346 _ => (),
347 }
348 }
349
350 Ok(())
351 }
352 }
353
354 const FIELDS: &[&str] = &[
355 #( stringify!(#field_names), )*
356 ];
357
358 #export::Deserializer::deserialize_struct(
359 #deserializer,
360 stringify!(#struct_name),
361 FIELDS,
362 __Visitor(self)
363 )
364 }
365 }
366 };
367
368 let const_name = Ident::new(
369 &format!("_IMPL_DESERIALIZE_OVER_FOR_{}", struct_name),
370 struct_name.span(),
371 );
372
373 Ok(
374 quote! {
375 #[allow(non_upper_case_globals, unused_attributes, unused_qualifications, non_camel_case_types)]
376 const #const_name: () = {
377 #inner
378 };
379 }
380 .into(),
381 )
382}
383
384fn impl_named_fields(
385 input: DeriveInput,
386 crate_name: Ident,
387 fields: FieldsNamed,
388) -> syn::Result<TokenStream> {
389 let fieldinfos = fields
390 .named
391 .iter()
392 .enumerate()
393 .map(|(idx, x)| {
394 let attr = parse_attr(x.attrs.iter())?;
395
396 let name = x.ident.clone().unwrap();
397
398 Ok(FieldInfo {
399 enum_value: Ident::new(&format!("__field{}", idx), name.span()),
400
401 name,
402 ty: x.ty.clone(),
403 passthrough: attr.use_deserialize_over,
404 deserialize_with: attr.deserialize_fn,
405 deserialize_merge_with: attr.deserialize_merge_fn,
406 srcname: attr.rename.map(|x| x.value()),
407 })
408 })
409 .collect::<Result<Vec<_>, syn::Error>>()?;
410
411 return impl_generic(input, crate_name, fieldinfos, false);
412}
413
414fn impl_unnamed_fields(
415 _input: DeriveInput,
416 _crate_name: Ident,
417 _fields: FieldsUnnamed,
418) -> syn::Result<TokenStream> {
419 panic!("Deriving DeserializeInto for tuple structs is not supported");
420}
421
422fn impl_unit(input: DeriveInput, crate_name: Ident) -> syn::Result<TokenStream> {
423 let struct_name = &input.ident;
424
425 Ok(
426 quote! {
427 impl ::#crate_name::DeserializeOver for #struct_name {
428 fn deserialize_over<'de, D>(&mut self, de: D) -> Result<(), D::Error>
429 where
430 D: Deserializer<'de>
431 {
432 Ok(())
433 }
434 }
435 }
436 .into(),
437 )
438}
439
440#[derive(Default)]
441struct ParsedAttr {
442 use_deserialize_over: bool,
443 deserialize_fn: Option<Path>,
444 deserialize_merge_fn: Option<Path>,
445 rename: Option<syn::LitStr>,
446}
447
448fn parse_attr<'a, I>(attrs: I) -> syn::Result<ParsedAttr>
449where
450 I: Iterator<Item = &'a Attribute>,
451{
452 use syn::spanned::Spanned;
453
454 let mut result = ParsedAttr::default();
455
456 for attr in attrs.into_iter() {
457 if attr.path.is_ident("deserialize_over") {
458 if !attr.tokens.is_empty() {
459 return Err(syn::Error::new_spanned(
460 attr.path.to_token_stream(),
461 "deserialize_over attribute should not have any arguments",
462 ));
463 }
464
465 result.use_deserialize_over = true;
466 } else if attr.path.is_ident("serde") {
467 let body: self::attr::SerdeAttrBody = syn::parse2(attr.tokens.clone())?;
468 let mut seen = HashSet::new();
469
470 for opt in body.attrs.iter() {
471 let ident = opt.ident().to_string();
472
473 match &*ident {
476 "with" | "deserialize_with" | "serialize_with" => (),
477 "rename" | "serialize" | "deserialize" => (),
478 "default" => (),
480 name => {
481 return Err(syn::Error::new(
482 opt.span(),
483 &format!(
484 r#"#[serde({}{}) is not supported by the DeserializeOver derive macro."#,
485 name,
486 if opt.is_flag() { r#" = "...""# } else { "" }
487 ),
488 ))
489 }
490 }
491
492 if !seen.insert(ident) {
493 return Err(syn::Error::new_spanned(
494 opt,
495 &format!(
496 "Option `{}` cannot be specified multiple times",
497 opt.ident()
498 ),
499 ));
500 }
501 }
502
503 if let Some(lit) = body.get("with") {
504 result.deserialize_fn = Some(
505 syn::parse_str(&(lit.value() + "::deserialize"))
506 .map_err(|e| syn::Error::new_spanned(lit, e))?,
507 );
508 result.deserialize_merge_fn = Some(
509 syn::parse_str(&(lit.value() + "::deserialize_over"))
510 .map_err(|e| syn::Error::new_spanned(lit, e))?,
511 );
512 }
513
514 if let Some(lit) = body.get("deserialize_with") {
515 if result.deserialize_fn.is_some() {
516 return Err(syn::Error::new(
517 body.span_for("deserialize_with"),
518 "Cannot specify both `with` and `deserialize_with`",
519 ));
520 }
521
522 result.deserialize_fn =
523 Some(syn::parse_str(&lit.value()).map_err(|e| syn::Error::new_spanned(lit, e))?);
524 }
525
526 if let Some(lit) = body.get("rename") {
527 result.rename = Some(lit.clone());
528 }
529
530 if let Some(lit) = body.get("deserialize") {
531 if result.rename.is_some() {
532 return Err(syn::Error::new(
533 body.span_for("deserialize"),
534 "Cannot specify both `rename` and `deserialize`",
535 ));
536 }
537
538 result.rename = Some(lit.clone());
539 }
540 }
541 }
542
543 Ok(result)
544}