1extern crate proc_macro;
143
144mod parse;
145
146use parse::Skip;
147use proc_macro::TokenStream;
148use proc_macro2::{Ident, Span};
149use quote::{format_ident, quote, quote_spanned};
150use syn::{
151 parse_macro_input, ImplGenerics, Lifetime, LifetimeParam, TypeGenerics, TypeParamBound,
152 WhereClause,
153};
154
155use crate::parse::Input;
156
157fn serialize_fields(
158 fields: &[parse::Field],
159 offset: usize,
160 impl_generics_serialize: ImplGenerics<'_>,
161 ty_generics_serialize: TypeGenerics<'_>,
162 ty_generics: &TypeGenerics<'_>,
163 where_clause: Option<&WhereClause>,
164 ident: &Ident,
165) -> Vec<proc_macro2::TokenStream> {
166 fields
167 .iter()
168 .filter(|field| !field.skip_serializing_if.is_always())
169 .map(|field| {
170 let index = field.index.expect("index must be set for fields that are not skipped") + offset;
172 let member = &field.member;
173 let serialize_member = match &field.serialize_with {
174 None => quote!(&self.#member),
175 Some(f) => {
176 let ty = &field.ty;
177 quote!({
178 struct __InternalSerdeIndexedSerializeWith #impl_generics_serialize {
179 value: &'__serde_indexed_lifetime #ty,
180 phantom: ::core::marker::PhantomData<#ident #ty_generics>,
181 }
182
183 impl #impl_generics_serialize serde::Serialize for __InternalSerdeIndexedSerializeWith #ty_generics_serialize #where_clause {
184 fn serialize<__S>(
185 &self,
186 __s: __S,
187 ) -> ::core::result::Result<__S::Ok, __S::Error>
188 where
189 __S: serde::Serializer,
190 {
191 #f(self.value, __s)
192 }
193 }
194
195 &__InternalSerdeIndexedSerializeWith { value: &self.#member, phantom: ::core::marker::PhantomData::<#ident #ty_generics> }
196 })
197 }
198 };
199
200 match &field.skip_serializing_if {
202 Skip::If(path) => quote! {
203 if !#path(&self.#member) {
204 map.serialize_entry(&#index, #serialize_member)?;
205 }
206 },
207 Skip::Always => unreachable!(),
208 Skip::Never => quote! {
209 map.serialize_entry(&#index, #serialize_member)?;
210 },
211 }
212 })
213 .collect()
214}
215
216fn count_serialized_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStream> {
217 fields
218 .iter()
219 .map(|field| {
220 let member = &field.member;
222 match &field.skip_serializing_if {
223 Skip::If(path) => {
224 quote! { if #path(&self.#member) { 0 } else { 1 } }
225 }
226 Skip::Always => quote! { 0 },
227
228 Skip::Never => {
229 quote! { 1 }
230 }
231 }
232 })
233 .collect()
234}
235
236#[proc_macro_derive(SerializeIndexed, attributes(serde, serde_indexed))]
237pub fn derive_serialize(input: TokenStream) -> TokenStream {
238 let input = parse_macro_input!(input as Input);
239 let ident = input.ident;
240 let num_fields = count_serialized_fields(&input.fields);
241 let (_, ty_generics, where_clause) = input.generics.split_for_impl();
242 let mut generics_cl = input.generics.clone();
243 generics_cl.type_params_mut().for_each(|t| {
244 t.bounds
245 .push_value(TypeParamBound::Verbatim(quote!(serde::Serialize)));
246 });
247 let (impl_generics, _, _) = generics_cl.split_for_impl();
248
249 let mut generics_cl2 = generics_cl.clone();
250
251 generics_cl2
252 .params
253 .push(syn::GenericParam::Lifetime(LifetimeParam::new(
254 Lifetime::new("'__serde_indexed_lifetime", Span::call_site()),
255 )));
256
257 let (impl_generics_serialize, ty_generics_serialize, _) = generics_cl2.split_for_impl();
258
259 let serialize_fields = serialize_fields(
260 &input.fields,
261 input.attrs.offset,
262 impl_generics_serialize,
263 ty_generics_serialize,
264 &ty_generics,
265 where_clause,
266 &ident,
267 );
268
269 TokenStream::from(quote! {
270 #[automatically_derived]
271 impl #impl_generics serde::Serialize for #ident #ty_generics #where_clause {
272 fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
273 where
274 S: serde::Serializer
275 {
276 use serde::ser::SerializeMap;
277 let num_fields = 0 #( + #num_fields)*;
278 let mut map = serializer.serialize_map(Some(num_fields))?;
279
280 #(#serialize_fields)*
281
282 map.end()
283 }
284 }
285 })
286}
287
288fn none_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStream> {
289 fields
290 .iter()
291 .filter(|f| !f.skip_serializing_if.is_always())
292 .map(|field| {
293 let ident = format_ident!("{}", &field.label);
294 let span = field.original_span;
295 quote_spanned! { span =>
296 let mut #ident = None;
297 }
298 })
299 .collect()
300}
301
302fn unwrap_expected_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStream> {
303 fields
304 .iter()
305 .map(|field| {
306 let label = field.label.clone();
307 let ident = format_ident!("{}", &field.label);
308 let span = field.original_span;
309 match field.skip_serializing_if {
310 Skip::Never => quote! {
311 let #ident = #ident.ok_or_else(|| serde::de::Error::missing_field(#label))?;
312 },
313 Skip::If(_) => quote_spanned! { span =>
314 let #ident = #ident.unwrap_or_default();
315 },
316 Skip::Always => quote! {
317 let #ident = ::core::default::Default::default();
318 },
319 }
320 })
321 .collect()
322}
323
324fn match_fields(
325 fields: &[parse::Field],
326 offset: usize,
327 impl_generics_with_de: &ImplGenerics<'_>,
328 ty_generics: &TypeGenerics<'_>,
329 ty_generics_with_de: &TypeGenerics<'_>,
330 where_clause: Option<&WhereClause>,
331 struct_ident: &Ident,
332) -> Vec<proc_macro2::TokenStream> {
333 fields
334 .iter()
335 .filter(|f| !f.skip_serializing_if.is_always())
336 .map(|field| {
337 let label = field.label.clone();
338 let ident = format_ident!("{}", &field.label);
339 let index = field.index.expect("index must be set for fields that are not skipped") + offset;
341 let span = field.original_span;
342
343 let next_value = match &field.deserialize_with {
344 Some(f) => {
345 let ty = &field.ty;
346 quote_spanned!(span => {
347 struct __InternalSerdeIndexedDeserializeWith #impl_generics_with_de {
348 value: #ty,
349 phantom: ::core::marker::PhantomData<#struct_ident #ty_generics>,
350 lifetime: ::core::marker::PhantomData<&'de ()>,
351 }
352 impl #impl_generics_with_de serde::Deserialize<'de> for __InternalSerdeIndexedDeserializeWith #ty_generics_with_de #where_clause {
353 fn deserialize<__D>(
354 __deserializer: __D,
355 ) -> Result<Self, __D::Error>
356 where
357 __D: serde::Deserializer<'de>,
358 {
359
360 Ok(__InternalSerdeIndexedDeserializeWith {
361 value: #f(__deserializer)?,
362 phantom: ::core::marker::PhantomData,
363 lifetime: ::core::marker::PhantomData,
364 })
365 }
366 }
367
368 let __InternalSerdeIndexedDeserializeWith { value, lifetime: _, phantom: _ } = map.next_value()?;
369 value
370 }
371 )
372 }
373 None => quote_spanned!(span => map.next_value()?),
374 };
375
376 quote_spanned!{ span =>
377 #index => {
378 if #ident.is_some() {
379 return Err(serde::de::Error::duplicate_field(#label));
380 }
381 let next_value = #next_value;
382 #ident = Some(next_value);
383 },
384 }
385 })
386 .collect()
387}
388
389fn all_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStream> {
390 fields
391 .iter()
392 .map(|field| {
393 let ident = format_ident!("{}", &field.label);
394 let span = field.original_span;
395 quote_spanned! { span =>
396 #ident
397 }
398 })
399 .collect()
400}
401
402#[proc_macro_derive(DeserializeIndexed, attributes(serde, serde_indexed))]
403pub fn derive_deserialize(input: TokenStream) -> TokenStream {
404 let input = parse_macro_input!(input as Input);
405 let ident = input.ident;
406 let none_fields = none_fields(&input.fields);
407 let unwrap_expected_fields = unwrap_expected_fields(&input.fields);
408 let all_fields = all_fields(&input.fields);
409
410 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
411
412 let mut generics_cl = input.generics.clone();
413 generics_cl.params.insert(
414 0,
415 syn::GenericParam::Lifetime(LifetimeParam {
416 attrs: Vec::new(),
417 lifetime: Lifetime {
418 apostrophe: Span::call_site(),
419 ident: Ident::new("de", Span::call_site()),
420 },
421 colon_token: None,
422 bounds: input
423 .generics
424 .lifetimes()
425 .map(|l| l.lifetime.clone())
426 .collect(),
427 }),
428 );
429 generics_cl.type_params_mut().for_each(|t| {
430 t.bounds
431 .push_value(TypeParamBound::Verbatim(quote!(serde::Deserialize<'de>)));
432 });
433
434 let (impl_generics_with_de, ty_generics_with_de, _) = generics_cl.split_for_impl();
435
436 let match_fields = match_fields(
437 &input.fields,
438 input.attrs.offset,
439 &impl_generics_with_de,
440 &ty_generics,
441 &ty_generics_with_de,
442 where_clause,
443 &ident,
444 );
445
446 let the_loop = if !input.fields.is_empty() {
447 quote! {
452 while let Some(__serde_indexed_internal_key) = map.next_key()? {
453 match __serde_indexed_internal_key {
454 #(#match_fields)*
455 _ => {
456 let _ = map.next_value::<serde::de::IgnoredAny>()?;
458 }
459 }
460 }
461 }
462 } else {
463 quote! {}
464 };
465
466 let res = quote! {
467 #[automatically_derived]
468 impl #impl_generics_with_de serde::Deserialize<'de> for #ident #ty_generics #where_clause {
469 fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
470 where
471 D: serde::Deserializer<'de>,
472 {
473 struct IndexedVisitor #impl_generics (core::marker::PhantomData<#ident #ty_generics>);
474
475 impl #impl_generics_with_de serde::de::Visitor<'de> for IndexedVisitor #ty_generics {
476 type Value = #ident #ty_generics;
477
478 fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
479 formatter.write_str(stringify!(#ident))
480 }
481
482 fn visit_map<V>(self, mut map: V) -> core::result::Result<Self::Value, V::Error>
483 where
484 V: serde::de::MapAccess<'de>,
485 {
486 #(#none_fields)*
487
488 #the_loop
489
490 #(#unwrap_expected_fields)*
491
492 Ok(#ident { #(#all_fields),* })
493 }
494 }
495
496 deserializer.deserialize_map(IndexedVisitor(Default::default()))
497 }
498 }
499 };
500 TokenStream::from(res)
501}