serde_indexed/
lib.rs

1/*! Derivation of [`Serialize`][serialize] and [`Deserialize`][deserialize] that replaces struct keys with numerical indices.
2
3### Usage example
4
5#### Struct attributes
6
7- `auto_index`: Automatically assign indices to the fields based on the order in the source code.  It is recommended to instead use the `index` attribute for all fields to explicitly assign indices.
8- `offset = ?`: If `auto_index` is set, use the given index for the first field instead of starting with zero.
9
#### Field attributes
11
- `index = ?`: Set the index for this field to the given value.  This attribute is required unless `auto_index` is set.  It cannot be used together with `auto_index`.
13- `skip`: Never serialize or deserialize this field.  This field still increases the assigned index if `auto_index` is used.
14- `skip(no_increment)`: Never serialize or deserialize this field and don’t increment the assigned index for this field if used together with the `auto_index` attribute.
15
16`serde-indexed` also supports these `serde` attributes:
17- [`deserialize_with`][deserialize-with]
18- [`serialize_with`][serialize-with]
19- [`skip_serializing_if`][skip-serializing-if]
20- [`with`][with]
21
22### Generated code example
Run `cargo expand --test basics` to inspect the code generated for the test suite, which exercises the macros using [`serde_cbor`][serde-cbor].
24
25### Examples
26
27Explicit index assignment:
28
29```
30use serde_indexed::{DeserializeIndexed, SerializeIndexed};
31
32#[derive(Clone, Debug, PartialEq, SerializeIndexed, DeserializeIndexed)]
33pub struct SomeKeys {
34    #[serde(index = 1)]
35    pub number: i32,
36    #[serde(index = 2)]
37    pub option: Option<u8>,
38    #[serde(skip)]
39    pub ignored: bool,
40    #[serde(index = 3)]
41    pub bytes: [u8; 7],
42}
43```
44
45Automatic index assignment:
46
47```
48use serde_indexed::{DeserializeIndexed, SerializeIndexed};
49
50#[derive(Clone, Debug, PartialEq, SerializeIndexed, DeserializeIndexed)]
51#[serde(auto_index)]
52pub struct SomeKeys {
    // index 0
    pub number: i32,
    // index 1
    pub option: Option<u8>,
    // index 2 (but skipped)
    #[serde(skip)]
    pub ignored: bool,
    // index 3
    pub bytes: [u8; 7],
62}
63```
64
65Automatic index assignment with `skip(no_increment)`:
66
67```
68use serde_indexed::{DeserializeIndexed, SerializeIndexed};
69
70#[derive(Clone, Debug, PartialEq, SerializeIndexed, DeserializeIndexed)]
71#[serde(auto_index)]
72pub struct SomeKeys {
    // index 0
    pub number: i32,
    // index 1
    pub option: Option<u8>,
    #[serde(skip(no_increment))]
    pub ignored: bool,
    // index 2
    pub bytes: [u8; 7],
81}
82```
83
84Automatic index assignment with `offset`:
85
86```
87use serde_indexed::{DeserializeIndexed, SerializeIndexed};
88
89#[derive(Clone, Debug, PartialEq, SerializeIndexed, DeserializeIndexed)]
90#[serde(auto_index, offset = 42)]
91pub struct SomeKeys {
92    // index 42
93    pub number: i32,
94    // index 43
95    pub option: Option<u8>,
96    // index 44
97    pub bytes: [u8; 7],
98}
99```
100
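Skip serializing a field based on a condition with `skip_serializing_if`.  Because such a field may be absent from the input, it is set to `Default::default()` when missing during deserialization, so its type must implement `Default`: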
102
103```
104use serde_indexed::{DeserializeIndexed, SerializeIndexed};
105
106#[derive(Clone, Debug, PartialEq, SerializeIndexed, DeserializeIndexed)]
107pub struct SomeKeys {
108    #[serde(index = 1)]
109    pub number: i32,
110    #[serde(index = 2, skip_serializing_if = "Option::is_none")]
111    pub option: Option<u8>,
112    #[serde(index = 3)]
113    pub bytes: [u8; 7],
114}
115```
116
117Change the serialization or deserialization format with `deserialize_with`, `serialize_with` or `with`:
118
119```
120use serde_indexed::{DeserializeIndexed, SerializeIndexed};
121
122#[derive(Clone, Debug, PartialEq, SerializeIndexed, DeserializeIndexed)]
123pub struct SomeKeys<'a> {
124    #[serde(index = 1, serialize_with = "serde_bytes::serialize")]
125    pub one: &'a [u8],
126    #[serde(index = 2, deserialize_with = "serde_bytes::deserialize")]
127    pub two: &'a [u8],
128    #[serde(index = 3, with = "serde_bytes")]
129    pub three: &'a [u8],
130}
131```
132
133[serialize]: https://docs.serde.rs/serde/ser/trait.Serialize.html
134[deserialize]: https://docs.serde.rs/serde/de/trait.Deserialize.html
135[deserialize-with]: https://serde.rs/field-attrs.html#deserialize_with
136[serialize-with]: https://serde.rs/field-attrs.html#serialize_with
137[with]: https://serde.rs/field-attrs.html#with
138[skip-serializing-if]: https://serde.rs/field-attrs.html#skip_serializing_if
139[serde-cbor]: https://docs.rs/serde_cbor
140*/
141
142extern crate proc_macro;
143
144mod parse;
145
146use parse::Skip;
147use proc_macro::TokenStream;
148use proc_macro2::{Ident, Span};
149use quote::{format_ident, quote, quote_spanned};
150use syn::{
151    parse_macro_input, ImplGenerics, Lifetime, LifetimeParam, TypeGenerics, TypeParamBound,
152    WhereClause,
153};
154
155use crate::parse::Input;
156
157fn serialize_fields(
158    fields: &[parse::Field],
159    offset: usize,
160    impl_generics_serialize: ImplGenerics<'_>,
161    ty_generics_serialize: TypeGenerics<'_>,
162    ty_generics: &TypeGenerics<'_>,
163    where_clause: Option<&WhereClause>,
164    ident: &Ident,
165) -> Vec<proc_macro2::TokenStream> {
166    fields
167        .iter()
168        .filter(|field| !field.skip_serializing_if.is_always())
169        .map(|field| {
170            // index should only be none if the field is always skipped, so this should never panic
171            let index = field.index.expect("index must be set for fields that are not skipped") + offset;
172            let member = &field.member;
173            let serialize_member = match &field.serialize_with {
174                None => quote!(&self.#member),
175                Some(f) => {
176                    let ty = &field.ty;
177                    quote!({
178                            struct __InternalSerdeIndexedSerializeWith #impl_generics_serialize {
179                                value: &'__serde_indexed_lifetime #ty,
180                                phantom: ::core::marker::PhantomData<#ident #ty_generics>,
181                            }
182
183                            impl #impl_generics_serialize serde::Serialize for __InternalSerdeIndexedSerializeWith #ty_generics_serialize #where_clause {
184                                fn serialize<__S>(
185                                    &self,
186                                    __s: __S,
187                                ) -> ::core::result::Result<__S::Ok, __S::Error>
188                                where
189                                    __S: serde::Serializer,
190                                {
191                                    #f(self.value, __s)
192                                }
193                            }
194
195                            &__InternalSerdeIndexedSerializeWith { value: &self.#member, phantom: ::core::marker::PhantomData::<#ident #ty_generics> }
196                    })
197                }
198            };
199
200            // println!("field {:?} index {:?}", &field.label, field.index);
201            match &field.skip_serializing_if {
202                Skip::If(path) => quote! {
203                    if !#path(&self.#member) {
204                        map.serialize_entry(&#index, #serialize_member)?;
205                    }
206                },
207                Skip::Always => unreachable!(),
208                Skip::Never => quote! {
209                    map.serialize_entry(&#index, #serialize_member)?;
210                },
211            }
212        })
213        .collect()
214}
215
216fn count_serialized_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStream> {
217    fields
218        .iter()
219        .map(|field| {
220            // let index = field.index + offset;
221            let member = &field.member;
222            match &field.skip_serializing_if {
223                Skip::If(path) => {
224                    quote! { if #path(&self.#member) { 0 } else { 1 } }
225                }
226                Skip::Always => quote! { 0 },
227
228                Skip::Never => {
229                    quote! { 1 }
230                }
231            }
232        })
233        .collect()
234}
235
236#[proc_macro_derive(SerializeIndexed, attributes(serde, serde_indexed))]
237pub fn derive_serialize(input: TokenStream) -> TokenStream {
238    let input = parse_macro_input!(input as Input);
239    let ident = input.ident;
240    let num_fields = count_serialized_fields(&input.fields);
241    let (_, ty_generics, where_clause) = input.generics.split_for_impl();
242    let mut generics_cl = input.generics.clone();
243    generics_cl.type_params_mut().for_each(|t| {
244        t.bounds
245            .push_value(TypeParamBound::Verbatim(quote!(serde::Serialize)));
246    });
247    let (impl_generics, _, _) = generics_cl.split_for_impl();
248
249    let mut generics_cl2 = generics_cl.clone();
250
251    generics_cl2
252        .params
253        .push(syn::GenericParam::Lifetime(LifetimeParam::new(
254            Lifetime::new("'__serde_indexed_lifetime", Span::call_site()),
255        )));
256
257    let (impl_generics_serialize, ty_generics_serialize, _) = generics_cl2.split_for_impl();
258
259    let serialize_fields = serialize_fields(
260        &input.fields,
261        input.attrs.offset,
262        impl_generics_serialize,
263        ty_generics_serialize,
264        &ty_generics,
265        where_clause,
266        &ident,
267    );
268
269    TokenStream::from(quote! {
270        #[automatically_derived]
271        impl #impl_generics serde::Serialize for #ident #ty_generics #where_clause  {
272            fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
273            where
274                S: serde::Serializer
275            {
276                use serde::ser::SerializeMap;
277                let num_fields = 0 #( + #num_fields)*;
278                let mut map = serializer.serialize_map(Some(num_fields))?;
279
280                #(#serialize_fields)*
281
282                map.end()
283            }
284        }
285    })
286}
287
288fn none_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStream> {
289    fields
290        .iter()
291        .filter(|f| !f.skip_serializing_if.is_always())
292        .map(|field| {
293            let ident = format_ident!("{}", &field.label);
294            let span = field.original_span;
295            quote_spanned! { span =>
296                let mut #ident = None;
297            }
298        })
299        .collect()
300}
301
302fn unwrap_expected_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStream> {
303    fields
304        .iter()
305        .map(|field| {
306            let label = field.label.clone();
307            let ident = format_ident!("{}", &field.label);
308            let span = field.original_span;
309            match field.skip_serializing_if {
310                Skip::Never => quote! {
311                    let #ident = #ident.ok_or_else(|| serde::de::Error::missing_field(#label))?;
312                },
313                Skip::If(_) => quote_spanned! { span =>
314                    let #ident = #ident.unwrap_or_default();
315                },
316                Skip::Always => quote! {
317                    let #ident = ::core::default::Default::default();
318                },
319            }
320        })
321        .collect()
322}
323
324fn match_fields(
325    fields: &[parse::Field],
326    offset: usize,
327    impl_generics_with_de: &ImplGenerics<'_>,
328    ty_generics: &TypeGenerics<'_>,
329    ty_generics_with_de: &TypeGenerics<'_>,
330    where_clause: Option<&WhereClause>,
331    struct_ident: &Ident,
332) -> Vec<proc_macro2::TokenStream> {
333    fields
334        .iter()
335        .filter(|f| !f.skip_serializing_if.is_always())
336        .map(|field| {
337            let label = field.label.clone();
338            let ident = format_ident!("{}", &field.label);
339            // index should only be none if the field is always skipped, so this should never panic
340            let index = field.index.expect("index must be set for fields that are not skipped") + offset;
341            let span = field.original_span;
342
343            let next_value = match &field.deserialize_with {
344                Some(f) => {
345                    let ty = &field.ty;
346                    quote_spanned!(span => {
347                            struct __InternalSerdeIndexedDeserializeWith #impl_generics_with_de {
348                                value: #ty,
349                                phantom: ::core::marker::PhantomData<#struct_ident #ty_generics>,
350                                lifetime: ::core::marker::PhantomData<&'de ()>,
351                            }
352                            impl #impl_generics_with_de serde::Deserialize<'de> for __InternalSerdeIndexedDeserializeWith #ty_generics_with_de #where_clause {
353                                fn deserialize<__D>(
354                                    __deserializer: __D,
355                                ) -> Result<Self, __D::Error>
356                                where
357                                    __D: serde::Deserializer<'de>,
358                                {
359
360                                    Ok(__InternalSerdeIndexedDeserializeWith {
361                                        value: #f(__deserializer)?,
362                                        phantom: ::core::marker::PhantomData,
363                                        lifetime: ::core::marker::PhantomData,
364                                    })
365                                }
366                            }
367
368                            let __InternalSerdeIndexedDeserializeWith { value, lifetime: _, phantom: _ } = map.next_value()?;
369                            value
370                        }
371                    )
372                }
373                None => quote_spanned!(span => map.next_value()?),
374            };
375
376            quote_spanned!{ span =>
377                #index => {
378                    if #ident.is_some() {
379                        return Err(serde::de::Error::duplicate_field(#label));
380                    }
381                    let next_value = #next_value;
382                    #ident = Some(next_value);
383                },
384            }
385        })
386        .collect()
387}
388
389fn all_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStream> {
390    fields
391        .iter()
392        .map(|field| {
393            let ident = format_ident!("{}", &field.label);
394            let span = field.original_span;
395            quote_spanned! { span =>
396                #ident
397            }
398        })
399        .collect()
400}
401
402#[proc_macro_derive(DeserializeIndexed, attributes(serde, serde_indexed))]
403pub fn derive_deserialize(input: TokenStream) -> TokenStream {
404    let input = parse_macro_input!(input as Input);
405    let ident = input.ident;
406    let none_fields = none_fields(&input.fields);
407    let unwrap_expected_fields = unwrap_expected_fields(&input.fields);
408    let all_fields = all_fields(&input.fields);
409
410    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
411
412    let mut generics_cl = input.generics.clone();
413    generics_cl.params.insert(
414        0,
415        syn::GenericParam::Lifetime(LifetimeParam {
416            attrs: Vec::new(),
417            lifetime: Lifetime {
418                apostrophe: Span::call_site(),
419                ident: Ident::new("de", Span::call_site()),
420            },
421            colon_token: None,
422            bounds: input
423                .generics
424                .lifetimes()
425                .map(|l| l.lifetime.clone())
426                .collect(),
427        }),
428    );
429    generics_cl.type_params_mut().for_each(|t| {
430        t.bounds
431            .push_value(TypeParamBound::Verbatim(quote!(serde::Deserialize<'de>)));
432    });
433
434    let (impl_generics_with_de, ty_generics_with_de, _) = generics_cl.split_for_impl();
435
436    let match_fields = match_fields(
437        &input.fields,
438        input.attrs.offset,
439        &impl_generics_with_de,
440        &ty_generics,
441        &ty_generics_with_de,
442        where_clause,
443        &ident,
444    );
445
446    let the_loop = if !input.fields.is_empty() {
447        // NB: In the previous "none_fields", we use the actual struct's
448        // keys as variable names. If the struct happens to have a key
449        // named "key", it would clash with __serde_indexed_internal_key,
450        // if that were named key.
451        quote! {
452            while let Some(__serde_indexed_internal_key) = map.next_key()? {
453                match __serde_indexed_internal_key {
454                    #(#match_fields)*
455                    _ => {
456                        // Ignore unknown keys by consuming their value
457                        let _ = map.next_value::<serde::de::IgnoredAny>()?;
458                    }
459                }
460            }
461        }
462    } else {
463        quote! {}
464    };
465
466    let res = quote! {
467        #[automatically_derived]
468        impl #impl_generics_with_de serde::Deserialize<'de> for #ident #ty_generics #where_clause {
469            fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
470            where
471                D: serde::Deserializer<'de>,
472            {
473                struct IndexedVisitor #impl_generics (core::marker::PhantomData<#ident #ty_generics>);
474
475                impl #impl_generics_with_de serde::de::Visitor<'de> for IndexedVisitor #ty_generics {
476                    type Value = #ident #ty_generics;
477
478                    fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
479                        formatter.write_str(stringify!(#ident))
480                    }
481
482                    fn visit_map<V>(self, mut map: V) -> core::result::Result<Self::Value, V::Error>
483                    where
484                        V: serde::de::MapAccess<'de>,
485                    {
486                        #(#none_fields)*
487
488                        #the_loop
489
490                        #(#unwrap_expected_fields)*
491
492                        Ok(#ident { #(#all_fields),* })
493                    }
494                }
495
496                deserializer.deserialize_map(IndexedVisitor(Default::default()))
497            }
498        }
499    };
500    TokenStream::from(res)
501}