Skip to main content

serde_feather_macros/
lib.rs

1#![forbid(unsafe_code)]
2
3//! Proc-macro derive implementation for `serde-feather`.
4
5use std::collections::HashMap;
6
7use proc_macro::TokenStream;
8use proc_macro2::{Span, TokenStream as TokenStream2};
9use proc_macro_crate::{crate_name, FoundCrate};
10use quote::{format_ident, quote};
11use syn::{
12    ext::IdentExt, parse_macro_input, spanned::Spanned, Attribute, Data, DeriveInput, Field,
13    Fields, Ident, LitStr,
14};
15
16#[proc_macro_derive(FeatherSerialize, attributes(serde))]
17pub fn derive_feather_serialize(input: TokenStream) -> TokenStream {
18    let input = parse_macro_input!(input as DeriveInput);
19    match expand_serialize(&input) {
20        Ok(output) => output.into(),
21        Err(error) => error.into_compile_error().into(),
22    }
23}
24
25#[proc_macro_derive(FeatherDeserialize, attributes(serde))]
26pub fn derive_feather_deserialize(input: TokenStream) -> TokenStream {
27    let input = parse_macro_input!(input as DeriveInput);
28    match expand_deserialize(&input) {
29        Ok(output) => output.into(),
30        Err(error) => error.into_compile_error().into(),
31    }
32}
33
34struct ContainerAttrOptions {
35    rename: Option<LitStr>,
36}
37
38#[derive(Default)]
39struct FieldAttrOptions {
40    rename: Option<LitStr>,
41    default: bool,
42    skip_serializing: bool,
43    skip_deserializing: bool,
44}
45
46struct ParsedField {
47    ident: Ident,
48    ty: syn::Type,
49    serialized_name: LitStr,
50    default: bool,
51    skip_serializing: bool,
52    skip_deserializing: bool,
53}
54
55struct ParsedStruct {
56    ident: Ident,
57    struct_name: LitStr,
58    fields: Vec<ParsedField>,
59}
60
61#[derive(Clone, Copy)]
62enum WireDirection {
63    Serialize,
64    Deserialize,
65}
66
67impl WireDirection {
68    fn includes(self, field: &ParsedField) -> bool {
69        match self {
70            Self::Serialize => !field.skip_serializing,
71            Self::Deserialize => !field.skip_deserializing,
72        }
73    }
74
75    fn name(self) -> &'static str {
76        match self {
77            Self::Serialize => "serialization",
78            Self::Deserialize => "deserialization",
79        }
80    }
81}
82
83fn expand_serialize(input: &DeriveInput) -> syn::Result<TokenStream2> {
84    let parsed = parse_input(input, "FeatherSerialize")?;
85    let crate_path = serde_feather_path();
86
87    let included_fields: Vec<&ParsedField> = parsed
88        .fields
89        .iter()
90        .filter(|field| !field.skip_serializing)
91        .collect();
92
93    let field_count = included_fields.len();
94    let serialize_fields = included_fields.into_iter().map(|field| {
95        let field_ident = &field.ident;
96        let field_name = &field.serialized_name;
97        quote! {
98            #crate_path::serde::ser::SerializeStruct::serialize_field(
99                &mut state,
100                #field_name,
101                &self.#field_ident,
102            )?;
103        }
104    });
105
106    let struct_ident = &parsed.ident;
107    let struct_name = &parsed.struct_name;
108
109    Ok(quote! {
110        impl #crate_path::serde::ser::Serialize for #struct_ident {
111            fn serialize<S>(
112                &self,
113                serializer: S,
114            ) -> ::core::result::Result<S::Ok, S::Error>
115            where
116                S: #crate_path::serde::ser::Serializer,
117            {
118                let mut state = #crate_path::serde::ser::Serializer::serialize_struct(
119                    serializer,
120                    #struct_name,
121                    #field_count,
122                )?;
123                #(#serialize_fields)*
124                #crate_path::serde::ser::SerializeStruct::end(state)
125            }
126        }
127    })
128}
129
130fn expand_deserialize(input: &DeriveInput) -> syn::Result<TokenStream2> {
131    let parsed = parse_input(input, "FeatherDeserialize")?;
132    let crate_path = serde_feather_path();
133
134    struct DeserBinding {
135        field_index: usize,
136        binding_ident: Ident,
137        field_name: LitStr,
138        field_ty: syn::Type,
139        default: bool,
140    }
141
142    let mut bindings = Vec::<DeserBinding>::new();
143    for (index, field) in parsed.fields.iter().enumerate() {
144        if field.skip_deserializing {
145            continue;
146        }
147
148        bindings.push(DeserBinding {
149            field_index: index,
150            binding_ident: format_ident!("__feather_field_{index}"),
151            field_name: field.serialized_name.clone(),
152            field_ty: field.ty.clone(),
153            default: field.default,
154        });
155    }
156
157    let field_bindings: Vec<TokenStream2> = bindings
158        .iter()
159        .map(|binding| {
160            let binding_ident = &binding.binding_ident;
161            let field_ty = &binding.field_ty;
162            quote! { let mut #binding_ident: ::core::option::Option<#field_ty> = ::core::option::Option::None; }
163        })
164        .collect();
165    let field_bindings_in_map = field_bindings.clone();
166    let field_bindings_in_seq = field_bindings;
167
168    let field_setter_match_arms = bindings.iter().enumerate().map(|(binding_index, binding)| {
169        let field_index = binding_index;
170        let binding_ident = &binding.binding_ident;
171        let field_name = &binding.field_name;
172        let field_ty = &binding.field_ty;
173        quote! {
174            #field_index => {
175                if #binding_ident.is_some() {
176                    return ::core::result::Result::Err(
177                        #crate_path::serde::de::Error::duplicate_field(#field_name),
178                    );
179                }
180                #binding_ident = ::core::option::Option::Some(#crate_path::serde::de::MapAccess::next_value::<#field_ty>(&mut map)?);
181            }
182        }
183    });
184
185    let known_fields: Vec<LitStr> = bindings
186        .iter()
187        .map(|binding| binding.field_name.clone())
188        .collect();
189    let known_fields_in_map = known_fields.clone();
190
191    let construct_fields: Vec<TokenStream2> = parsed
192        .fields
193        .iter()
194        .enumerate()
195        .map(|(index, field)| {
196            let field_ident = &field.ident;
197            let field_name = &field.serialized_name;
198            if field.skip_deserializing {
199                return quote! {
200                    #field_ident: ::core::default::Default::default()
201                };
202            }
203
204            let binding_ident = bindings
205                .iter()
206                .find(|binding| binding.field_index == index)
207                .expect("binding for non-skipped field")
208                .binding_ident
209                .clone();
210
211            if field.default {
212                quote! {
213                    #field_ident: #binding_ident.unwrap_or_default()
214                }
215            } else {
216                quote! {
217                    #field_ident: match #binding_ident {
218                        ::core::option::Option::Some(value) => value,
219                        ::core::option::Option::None => {
220                            return ::core::result::Result::Err(
221                                #crate_path::serde::de::Error::missing_field(#field_name),
222                            );
223                        }
224                    }
225                }
226            }
227        })
228        .collect();
229    let construct_fields_in_map = construct_fields.clone();
230    let construct_fields_in_seq = construct_fields;
231
232    let seq_field_decode_steps = bindings.iter().enumerate().map(|(seq_index, binding)| {
233        let binding_ident = &binding.binding_ident;
234        let field_ty = &binding.field_ty;
235        if binding.default {
236            quote! {
237                if let ::core::option::Option::Some(value) =
238                    #crate_path::serde::de::SeqAccess::next_element::<#field_ty>(&mut seq)?
239                {
240                    #binding_ident = ::core::option::Option::Some(value);
241                }
242            }
243        } else {
244            quote! {
245                #binding_ident =
246                    match #crate_path::serde::de::SeqAccess::next_element::<#field_ty>(&mut seq)? {
247                        ::core::option::Option::Some(value) => ::core::option::Option::Some(value),
248                        ::core::option::Option::None => {
249                            return ::core::result::Result::Err(
250                                #crate_path::serde::de::Error::invalid_length(#seq_index, &self),
251                            );
252                        }
253                    };
254            }
255        }
256    });
257
258    let seq_expected_len = bindings.len();
259
260    let struct_ident = &parsed.ident;
261    let struct_name = &parsed.struct_name;
262
263    Ok(quote! {
264        impl<'de> #crate_path::serde::de::Deserialize<'de> for #struct_ident {
265            fn deserialize<D>(deserializer: D) -> ::core::result::Result<Self, D::Error>
266            where
267                D: #crate_path::serde::de::Deserializer<'de>,
268            {
269                struct __FeatherVisitor;
270
271                impl<'de> #crate_path::serde::de::Visitor<'de> for __FeatherVisitor {
272                    type Value = #struct_ident;
273
274                    fn expecting(
275                        &self,
276                        formatter: &mut ::core::fmt::Formatter<'_>,
277                    ) -> ::core::fmt::Result {
278                        ::core::write!(formatter, "struct {}", #struct_name)
279                    }
280
281                    fn visit_map<V>(
282                        self,
283                        mut map: V,
284                    ) -> ::core::result::Result<Self::Value, V::Error>
285                    where
286                        V: #crate_path::serde::de::MapAccess<'de>,
287                    {
288                        const __FEATHER_FIELDS: &[&str] = &[#(#known_fields_in_map),*];
289                        #(#field_bindings_in_map)*
290                        while let ::core::option::Option::Some(key) = #crate_path::serde::de::MapAccess::next_key::<#crate_path::__private::OwnedFieldName>(&mut map)?
291                        {
292                            match #crate_path::__private::select_field_index(key.as_str(), __FEATHER_FIELDS) {
293                                ::core::option::Option::Some(index) => match index {
294                                    #(#field_setter_match_arms)*
295                                    _ => {
296                                        let _: #crate_path::serde::de::IgnoredAny =
297                                            #crate_path::serde::de::MapAccess::next_value(&mut map)?;
298                                    }
299                                },
300                                ::core::option::Option::None => {
301                                    let _: #crate_path::serde::de::IgnoredAny =
302                                        #crate_path::serde::de::MapAccess::next_value(&mut map)?;
303                                }
304                            }
305                        }
306
307                        ::core::result::Result::Ok(#struct_ident {
308                            #(#construct_fields_in_map,)*
309                        })
310                    }
311
312                    fn visit_seq<V>(
313                        self,
314                        mut seq: V,
315                    ) -> ::core::result::Result<Self::Value, V::Error>
316                    where
317                        V: #crate_path::serde::de::SeqAccess<'de>,
318                    {
319                        #(#field_bindings_in_seq)*
320                        #(#seq_field_decode_steps)*
321
322                        if #crate_path::serde::de::SeqAccess::next_element::<#crate_path::serde::de::IgnoredAny>(&mut seq)?.is_some() {
323                            return ::core::result::Result::Err(
324                                #crate_path::serde::de::Error::invalid_length(#seq_expected_len + 1, &self),
325                            );
326                        }
327
328                        ::core::result::Result::Ok(#struct_ident {
329                            #(#construct_fields_in_seq,)*
330                        })
331                    }
332                }
333
334                const __FEATHER_FIELDS: &[&str] = &[#(#known_fields),*];
335                #crate_path::serde::de::Deserializer::deserialize_struct(
336                    deserializer,
337                    #struct_name,
338                    __FEATHER_FIELDS,
339                    __FeatherVisitor,
340                )
341            }
342        }
343    })
344}
345
346fn parse_input(input: &DeriveInput, macro_name: &str) -> syn::Result<ParsedStruct> {
347    if !input.generics.params.is_empty() || input.generics.where_clause.is_some() {
348        return Err(syn::Error::new_spanned(
349            &input.generics,
350            format!("{macro_name} only supports non-generic structs in this MVP"),
351        ));
352    }
353
354    let container_options = parse_container_attributes(&input.attrs)?;
355    let struct_name = container_options
356        .rename
357        .unwrap_or_else(|| LitStr::new(&input.ident.to_string(), input.ident.span()));
358
359    let named_fields = match &input.data {
360        Data::Struct(data_struct) => match &data_struct.fields {
361            Fields::Named(fields) => &fields.named,
362            _ => {
363                return Err(syn::Error::new_spanned(
364                    &data_struct.fields,
365                    format!("{macro_name} only supports structs with named fields in this MVP"),
366                ))
367            }
368        },
369        _ => {
370            return Err(syn::Error::new_spanned(
371                &input.ident,
372                format!("{macro_name} only supports structs in this MVP"),
373            ))
374        }
375    };
376
377    let mut parsed_fields = Vec::with_capacity(named_fields.len());
378    for field in named_fields {
379        parsed_fields.push(parse_field(field)?);
380    }
381
382    validate_unique_wire_field_names(&parsed_fields, WireDirection::Serialize)?;
383    validate_unique_wire_field_names(&parsed_fields, WireDirection::Deserialize)?;
384
385    Ok(ParsedStruct {
386        ident: input.ident.clone(),
387        struct_name,
388        fields: parsed_fields,
389    })
390}
391
392fn validate_unique_wire_field_names(
393    parsed_fields: &[ParsedField],
394    direction: WireDirection,
395) -> syn::Result<()> {
396    let mut seen_by_name: HashMap<String, String> = HashMap::new();
397
398    for field in parsed_fields {
399        if !direction.includes(field) {
400            continue;
401        }
402
403        let wire_name = field.serialized_name.value();
404        let current_field = field.ident.to_string();
405        if let Some(previous_field) = seen_by_name.insert(wire_name.clone(), current_field) {
406            return Err(syn::Error::new(
407                field.serialized_name.span(),
408                format!(
409                    "duplicate wire field name `{wire_name}` in {}; conflicts with field \
410                     `{previous_field}`",
411                    direction.name()
412                ),
413            ));
414        }
415    }
416
417    Ok(())
418}
419
420fn parse_container_attributes(attrs: &[Attribute]) -> syn::Result<ContainerAttrOptions> {
421    let mut options = ContainerAttrOptions { rename: None };
422
423    for attr in attrs {
424        if !attr.path().is_ident("serde") {
425            continue;
426        }
427
428        attr.parse_nested_meta(|meta| {
429            if meta.path.is_ident("rename") {
430                let rename_value: LitStr = meta.value()?.parse()?;
431                if options.rename.replace(rename_value).is_some() {
432                    return Err(meta.error("duplicate serde container attribute `rename`"));
433                }
434                return Ok(());
435            }
436
437            Err(meta.error("unsupported serde container attribute; supported attributes: `rename`"))
438        })?;
439    }
440
441    Ok(options)
442}
443
444fn parse_field(field: &Field) -> syn::Result<ParsedField> {
445    let field_ident = field.ident.clone().ok_or_else(|| {
446        syn::Error::new(
447            field.span(),
448            "Feather derives only support fields with identifiers",
449        )
450    })?;
451
452    let mut options = FieldAttrOptions::default();
453
454    for attr in &field.attrs {
455        if !attr.path().is_ident("serde") {
456            continue;
457        }
458
459        attr.parse_nested_meta(|meta| {
460            if meta.path.is_ident("rename") {
461                let rename_value: LitStr = meta.value()?.parse()?;
462                if options.rename.replace(rename_value).is_some() {
463                    return Err(meta.error("duplicate serde field attribute `rename`"));
464                }
465                return Ok(());
466            }
467
468            if meta.path.is_ident("default") {
469                ensure_flag_meta_has_no_value(&meta, "default")?;
470                if options.default {
471                    return Err(meta.error("duplicate serde field attribute `default`"));
472                }
473                options.default = true;
474                return Ok(());
475            }
476
477            if meta.path.is_ident("skip") {
478                ensure_flag_meta_has_no_value(&meta, "skip")?;
479                if options.skip_serializing || options.skip_deserializing {
480                    return Err(meta.error(
481                        "serde field attribute `skip` conflicts with previously declared `skip`, \
482                         `skip_serializing`, or `skip_deserializing`",
483                    ));
484                }
485                options.skip_serializing = true;
486                options.skip_deserializing = true;
487                return Ok(());
488            }
489
490            if meta.path.is_ident("skip_serializing") {
491                ensure_flag_meta_has_no_value(&meta, "skip_serializing")?;
492                if options.skip_serializing {
493                    return Err(meta.error("duplicate serde field attribute `skip_serializing`"));
494                }
495                if options.skip_deserializing {
496                    return Err(meta.error(
497                        "serde field attributes `skip_serializing` and `skip_deserializing` \
498                         cannot be combined",
499                    ));
500                }
501                options.skip_serializing = true;
502                return Ok(());
503            }
504
505            if meta.path.is_ident("skip_deserializing") {
506                ensure_flag_meta_has_no_value(&meta, "skip_deserializing")?;
507                if options.skip_deserializing {
508                    return Err(meta.error("duplicate serde field attribute `skip_deserializing`"));
509                }
510                if options.skip_serializing {
511                    return Err(meta.error(
512                        "serde field attributes `skip_serializing` and `skip_deserializing` \
513                         cannot be combined",
514                    ));
515                }
516                options.skip_deserializing = true;
517                return Ok(());
518            }
519
520            Err(meta.error(
521                "unsupported serde field attribute; supported attributes: `rename`, `default`, \
522                 `skip`, `skip_serializing`, `skip_deserializing`",
523            ))
524        })?;
525    }
526
527    let serialized_name = options
528        .rename
529        .unwrap_or_else(|| LitStr::new(&field_ident.unraw().to_string(), field_ident.span()));
530
531    Ok(ParsedField {
532        ident: field_ident,
533        ty: field.ty.clone(),
534        serialized_name,
535        default: options.default,
536        skip_serializing: options.skip_serializing,
537        skip_deserializing: options.skip_deserializing,
538    })
539}
540
541fn ensure_flag_meta_has_no_value(
542    meta: &syn::meta::ParseNestedMeta<'_>,
543    name: &str,
544) -> syn::Result<()> {
545    if !meta.input.peek(syn::Token![=]) && !meta.input.peek(syn::token::Paren) {
546        return Ok(());
547    }
548
549    Err(meta.error(format!(
550        "serde field attribute `{name}` does not accept a value"
551    )))
552}
553
554fn serde_feather_path() -> TokenStream2 {
555    match crate_name("serde-feather") {
556        Ok(FoundCrate::Itself) => quote!(crate),
557        Ok(FoundCrate::Name(name)) => {
558            let ident = Ident::new(&name.replace('-', "_"), Span::call_site());
559            quote!(::#ident)
560        }
561        Err(_) => quote!(::serde_feather),
562    }
563}