serde_partial_macro/
lib.rs

1use proc_macro::{Span, TokenStream};
2use quote::ToTokens;
3use serde_derive_internals::{
4    ast::{Container, Data, Style},
5    Ctxt, Derive,
6};
7use syn::{DeriveInput, Error};
8
9#[proc_macro_derive(SerializePartial, attributes(serde))]
10pub fn serialize_partial(input: TokenStream) -> TokenStream {
11    let cx = Ctxt::new();
12    let item = syn::parse_macro_input!(input as DeriveInput);
13    let Container {
14        data,
15        attrs,
16        ident,
17        original,
18        ..
19    } = match Container::from_ast(&cx, &item, Derive::Serialize) {
20        Some(c) => c,
21        None => return item.to_token_stream().into(),
22    };
23    let ident = &ident;
24    let vis = &original.vis;
25
26    if cx.check().is_err() {
27        return item.to_token_stream().into();
28    }
29
30    let mut fields = match data {
31        Data::Struct(Style::Struct, f) => f,
32        _ => {
33            return Error::new(
34                Span::call_site().into(),
35                "SerializePartial only supports structs",
36            )
37            .to_compile_error()
38            .into()
39        }
40    };
41    for f in fields.iter_mut() {
42        f.attrs.rename_by_rules(attrs.rename_all_rules());
43    }
44    fields.retain(|f| !f.attrs.skip_serializing());
45
46    let field_idents = fields
47        .iter()
48        .map(|f| f.original.ident.as_ref().unwrap())
49        .collect::<Vec<_>>();
50    let field_idents = &field_idents;
51
52    let field_names = fields
53        .iter()
54        .map(|f| f.attrs.name().serialize_name())
55        .collect::<Vec<_>>();
56    let field_names = &field_names;
57
58    let fields_len = fields.len();
59
60    let fields_struct_ident = &quote::format_ident!("{}Fields", ident);
61    let filter_struct_ident = &quote::format_ident!("{}Filter", ident);
62
63    let fields_struct = quote::quote! {
64        #[derive(Debug, Clone, Copy)]
65        #vis struct #fields_struct_ident {
66            #(
67                pub #field_idents: ::serde_partial::Field<'static, #ident>,
68            )*
69        }
70
71        impl #fields_struct_ident {
72            pub const FIELDS: Self = Self {
73                #(
74                    #field_idents: ::serde_partial::Field::new(#field_names),
75                )*
76            };
77        }
78
79        impl ::core::iter::IntoIterator for #fields_struct_ident {
80            type Item = ::serde_partial::Field<'static, #ident>;
81            type IntoIter = ::core::array::IntoIter<Self::Item, #fields_len>;
82
83            fn into_iter(self) -> Self::IntoIter {
84                #[allow(deprecated)]
85                ::core::array::IntoIter::new([
86                    #(
87                        self.#field_idents,
88                    )*
89                ])
90            }
91        }
92    };
93
94    let filter_struct = quote::quote! {
95        #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
96        #vis struct #filter_struct_ident {
97            #(
98                #field_idents: bool,
99            )*
100        }
101
102        impl ::serde_partial::SerializeFilter<#ident> for #filter_struct_ident {
103            fn skip(&self, field: ::serde_partial::Field<'_, #ident>) -> bool {
104                match field.name() {
105                    #(
106                        #field_names => !self.#field_idents,
107                    )*
108                    _ => panic!("unknown field"),
109                }
110            }
111
112            fn filtered_len(&self, _len: Option<usize>) -> Option<usize> {
113                let mut len = 0;
114                #(
115                    if self.#field_idents {
116                        len += 1;
117                    }
118                )*
119                Some(len)
120            }
121        }
122    };
123
124    let trait_impl = quote::quote! {
125        impl<'a> ::serde_partial::SerializePartial<'a> for #ident {
126            type Fields = #fields_struct_ident;
127            type Filter = #filter_struct_ident;
128
129            fn with_fields<F, I>(&'a self, select: F) -> ::serde_partial::Partial<'a, Self>
130            where
131                F: ::core::ops::FnOnce(Self::Fields) -> I,
132                I: ::core::iter::IntoIterator<Item = ::serde_partial::Field<'a, Self>>,
133            {
134                let fields = Self::Fields::FIELDS;
135                let mut filter = <Self::Filter as ::core::default::Default>::default();
136
137                for filtered in select(fields) {
138                    match filtered.name() {
139                        #(
140                            #field_names => { filter.#field_idents = true }
141                        )*
142                        _ => panic!("unknown field"),
143                    }
144                }
145
146                ::serde_partial::Partial {
147                    value: self,
148                    filter,
149                }
150            }
151        }
152    };
153
154    let derive = quote::quote! {
155        #[doc(hidden)]
156        #[allow(non_upper_case_globals, non_camel_case_types)]
157        const _: () = {
158            #fields_struct
159            #filter_struct
160            #trait_impl
161        };
162    };
163    derive.into()
164}