serde_split/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use proc_macro_crate::FoundCrate;
4use syn::{Attribute, Data, DeriveInput, Lifetime, LifetimeParam, Meta, WherePredicate};
5
6fn find_serde_crate() -> proc_macro2::TokenStream {
7    match proc_macro_crate::crate_name("serde") {
8        Ok(FoundCrate::Itself) => quote::quote!(crate),
9        Ok(FoundCrate::Name(name)) => {
10            let ident = syn::Ident::new(name.as_str(), Span::call_site());
11            quote::quote!(::#ident)
12        }
13        Err(_) => {
14            panic!("serde is a co-dependency of serde-split")
15        }
16    }
17}
18
19fn filter_attrs(attrs: &mut Vec<Attribute>, is_json: bool) {
20    let replace = if is_json { "json" } else { "bin" };
21
22    let mut current = 0;
23    while current < attrs.len() {
24        if attrs[current].path().is_ident(replace) {
25            match &mut attrs[current].meta {
26                Meta::Path(path) => *path = syn::parse_quote!(serde),
27                Meta::List(list) => list.path = syn::parse_quote!(serde),
28                Meta::NameValue(name_value) => name_value.path = syn::parse_quote!(serde),
29            }
30        } else if !attrs[current].path().is_ident("serde") {
31            attrs.remove(current);
32            continue;
33        }
34
35        current += 1;
36    }
37}
38
39fn filter_data(input: &mut DeriveInput, is_json: bool) {
40    filter_attrs(&mut input.attrs, is_json);
41
42    match &mut input.data {
43        Data::Struct(data) => {
44            data.fields
45                .iter_mut()
46                .for_each(|field| filter_attrs(&mut field.attrs, is_json));
47        }
48        Data::Enum(data) => {
49            data.variants.iter_mut().for_each(|variant| {
50                filter_attrs(&mut variant.attrs, is_json);
51                variant
52                    .fields
53                    .iter_mut()
54                    .for_each(|field| filter_attrs(&mut field.attrs, is_json));
55            });
56        }
57        Data::Union(data) => {
58            data.fields.named.iter_mut().for_each(|field| {
59                filter_attrs(&mut field.attrs, is_json);
60            });
61        }
62    }
63}
64
65#[proc_macro_derive(Serialize, attributes(json, bin, serde))]
66pub fn derive_serialize(tokens: TokenStream) -> TokenStream {
67    let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
68
69    let ident = input.ident.clone();
70
71    let mut json = input.clone();
72    let mut bin = input;
73
74    filter_data(&mut json, true);
75    filter_data(&mut bin, false);
76
77    json.ident = quote::format_ident!("{}JsonImpl", ident);
78    bin.ident = quote::format_ident!("{}BinaryImpl", ident);
79
80    let json_ident = &json.ident;
81    let bin_ident = &bin.ident;
82
83    let ident_str = syn::LitStr::new(ident.to_string().as_str(), ident.span());
84
85    let serde = find_serde_crate();
86
87    let (impl_gen, ty_gen, where_clause) = bin.generics.split_for_impl();
88
89    let where_clause = if let Some(clause) = where_clause {
90        let mut clause = clause.clone();
91        clause
92            .predicates
93            .extend(bin.generics.params.iter().filter_map(|param| match param {
94                syn::GenericParam::Type(ty) => {
95                    let ident = &ty.ident;
96                    Some::<WherePredicate>(syn::parse_quote!(#ident: #serde::Serialize))
97                }
98                _ => None,
99            }));
100        Some(clause)
101    } else if !bin.generics.params.is_empty() {
102        let clauses = bin.generics.params.iter().filter_map(|param| match param {
103            syn::GenericParam::Type(ty) => {
104                let ident = &ty.ident;
105                Some::<WherePredicate>(syn::parse_quote!(#ident: #serde::Serialize))
106            }
107            _ => None,
108        });
109
110        Some(syn::parse_quote!(where #(#clauses,)*))
111    } else {
112        None
113    };
114
115    quote::quote! {
116        const _: () = {
117            #[derive(#serde::Serialize)]
118            #[serde(remote = #ident_str)]
119            #json
120
121            #[derive(#serde::Serialize)]
122            #[serde(remote = #ident_str)]
123            #bin
124
125            impl #impl_gen #serde::Serialize for #ident #ty_gen #where_clause {
126                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
127                    where S: #serde::Serializer
128                {
129                    if serializer.is_human_readable() {
130                        #json_ident::serialize(self, serializer)
131                    } else {
132                        #bin_ident::serialize(self, serializer)
133                    }
134                }
135            }
136        };
137    }
138    .into()
139}
140
141#[proc_macro_derive(Deserialize, attributes(json, bin, serde))]
142pub fn derive_deserialize(tokens: TokenStream) -> TokenStream {
143    let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
144
145    let ident = input.ident.clone();
146
147    let mut json = input.clone();
148    let mut bin = input;
149
150    filter_data(&mut json, true);
151    filter_data(&mut bin, false);
152
153    json.ident = quote::format_ident!("{}JsonImpl", ident);
154    bin.ident = quote::format_ident!("{}BinaryImpl", ident);
155
156    let json_ident = &json.ident;
157    let bin_ident = &bin.ident;
158
159    let ident_str = syn::LitStr::new(ident.to_string().as_str(), ident.span());
160
161    let serde = find_serde_crate();
162
163    let mut impl_generics = bin.generics.clone();
164
165    impl_generics.params.insert(
166        0,
167        syn::GenericParam::Lifetime(LifetimeParam::new(Lifetime::new("'de", Span::call_site()))),
168    );
169
170    let (_, ty_gen, where_clause) = bin.generics.split_for_impl();
171
172    let (impl_gen, _, _) = impl_generics.split_for_impl();
173
174    let where_clause = if let Some(clause) = where_clause {
175        let mut clause = clause.clone();
176        clause
177            .predicates
178            .extend(bin.generics.params.iter().filter_map(|param| match param {
179                syn::GenericParam::Type(ty) => {
180                    let ident = &ty.ident;
181                    Some::<WherePredicate>(syn::parse_quote!(#ident: #serde::Deserialize<'de>))
182                }
183                _ => None,
184            }));
185        Some(clause)
186    } else if !bin.generics.params.is_empty() {
187        let clauses = bin.generics.params.iter().filter_map(|param| match param {
188            syn::GenericParam::Type(ty) => {
189                let ident = &ty.ident;
190                Some::<WherePredicate>(syn::parse_quote!(#ident: #serde::Deserialize<'de>))
191            }
192            _ => None,
193        });
194
195        Some(syn::parse_quote!(where #(#clauses,)*))
196    } else {
197        None
198    };
199
200    quote::quote! {
201        const _: () = {
202            #[derive(#serde::Deserialize)]
203            #[serde(remote = #ident_str)]
204            #json
205
206            #[derive(#serde::Deserialize)]
207            #[serde(remote = #ident_str)]
208            #bin
209
210            impl #impl_gen #serde::Deserialize<'de> for #ident #ty_gen #where_clause {
211                fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
212                    where D: #serde::Deserializer<'de>
213                {
214                    if deserializer.is_human_readable() {
215                        #json_ident::deserialize(deserializer)
216                    } else {
217                        #bin_ident::deserialize(deserializer)
218                    }
219                }
220            }
221        };
222    }
223    .into()
224}