serde_roundtrip_derive/
lib.rs

1extern crate proc_macro;
2extern crate syn;
3#[macro_use]
4extern crate quote;
5
6use proc_macro::TokenStream;
7use syn::fold::Folder;
8use syn::fold::noop_fold_generics;
9use syn::fold::noop_fold_path;
10use syn::AngleBracketedParameterData;
11use syn::Generics;
12use syn::Ident;
13use syn::Lifetime;
14use syn::Path;
15use syn::PathParameters;
16use syn::PathSegment;
17use syn::PolyTraitRef;
18use syn::TraitBoundModifier;
19use syn::Ty;
20use syn::TyParam;
21use syn::TyParamBound;
22use syn::WhereClause;
23
24#[proc_macro_derive(RoundTrip)]
25pub fn round_trip(input: TokenStream) -> TokenStream {
26    let s = input.to_string();
27    let ast = syn::parse_macro_input(&s).unwrap();
28    let gen = impl_round_trip(&ast);
29    gen.parse().unwrap()
30}
31
// Rename the lifetimes and type parameters of a generic type declaration.
33
34struct Renaming<'a> {
35    original: &'a Generics,
36    lifetime_prefix: &'a str,
37    ty_param_prefix: &'a str,
38}
39
40impl<'a> Folder for Renaming<'a> {
41    fn fold_generics(&mut self, generics: Generics) -> Generics {
42        let mut result = noop_fold_generics(self, generics);
43        for ty_param in &mut result.ty_params {
44            ty_param.ident = self.fold_ty_param_ident(ty_param.ident.clone());
45        }
46        result
47    }
48    fn fold_lifetime(&mut self, lifetime: Lifetime) -> Lifetime {
49        Lifetime { ident: self.fold_lifetime_ident(lifetime.ident) }
50    }
51    fn fold_path(&mut self, path: Path) -> Path {
52        let mut result = noop_fold_path(self, path);
53        if let Some((segment, rest)) = result.segments.split_first_mut() {
54            if rest.is_empty() && segment.parameters.is_empty() {
55                segment.ident = self.fold_ty_param_ident(segment.ident.clone());
56            }
57        }
58        result
59    }
60}
61
62impl<'a> Renaming<'a> {
63    fn fold_lifetime_ident(&mut self, ident: Ident) -> Ident {
64        self.original.lifetimes.iter()
65            .position(|original| original.lifetime.ident == ident)
66            .map(|index| syn::Ident::from(format!("{}{}", self.lifetime_prefix, index)))
67            .unwrap_or(ident)
68    }
69    fn fold_ty_param_ident(&mut self, ident: Ident) -> Ident {
70        self.original.ty_params.iter()
71            .position(|original| original.ident == ident)
72            .map(|index| syn::Ident::from(format!("{}{}", self.ty_param_prefix, index)))
73            .unwrap_or(ident)
74    }
75}
76
// Convert an ident and its generic parameters to a path, e.g. Foo<'a, X, Y>.
78
79fn generic_path(ident: &Ident, generics: &Generics) -> Path {
80    Path {
81        global: false,
82        segments: vec![ PathSegment {
83            ident: ident.clone(),
84            parameters: PathParameters::AngleBracketed(AngleBracketedParameterData {
85                lifetimes: generics.lifetimes.iter()
86                    .map(|lifetime_def| lifetime_def.lifetime.clone())
87                    .collect(),
88                types: generics.ty_params.iter()
89                    .map(|ty_param| Ty::Path(None, Path::from(ty_param.ident.clone())))
90                    .collect(),
91                bindings: vec![],
92            }),
93        } ]
94    }
95}
96
// Build a trait bound from the textual form of a trait path.
98
99fn ty_param_bound(text: &str) -> TyParamBound {
100    TyParamBound::Trait(
101        PolyTraitRef {
102            bound_lifetimes: vec![],
103            trait_ref: syn::parse::path(text).expect("Unexpected parse error"),
104        },
105        TraitBoundModifier::None,
106    )
107}
108
// Generate the RoundTrip and SameDeserialization impls for a derive input.
110
111fn impl_round_trip(ast: &syn::MacroInput) -> quote::Tokens {
112    let name = &ast.ident;
113
114    // If the original is Foo<'l, X, Y>, the target type is Foo<'b0, T0, T1>.
115    let mut target_renaming = Renaming { original: &ast.generics, lifetime_prefix: "'b", ty_param_prefix: "T" };
116    let mut target_generics = target_renaming.fold_generics(ast.generics.clone());
117    for ty_param in target_generics.ty_params.iter_mut() {
118        ty_param.bounds.push(ty_param_bound("::serde::Deserialize"));
119    }
120    let target_where_clause = target_generics.where_clause.clone();
121    let target_path = generic_path(&ast.ident, &target_generics);
122
123    // The target type parameter is T: SameDeserialization<SameAs=Foo<'b0, T0, T1>>.
124    let target_ty_param_bound = quote! { ::serde_roundtrip::SameDeserialization<SameAs=#target_path> };
125    let target_ty_param = TyParam {
126        attrs: vec![],
127        ident: Ident::from("T"),
128        bounds: vec![ty_param_bound(target_ty_param_bound.as_str())],
129        default: None,
130    };
131
132    // If the original is Foo<'l, X, Y>, the source type is Foo<'a0, S0, S1>.
133    let mut source_renaming = Renaming { original: &ast.generics, lifetime_prefix: "'a", ty_param_prefix: "S" };
134    let mut source_generics = source_renaming.fold_generics(ast.generics.clone());
135    for (ty_param, target_ty_param) in source_generics.ty_params.iter_mut().zip(target_generics.ty_params.iter()) {
136        let target_ty_param_ident = &target_ty_param.ident;
137        let text = quote! { ::serde_roundtrip::RoundTrip<#target_ty_param_ident> };
138        ty_param.bounds.push(ty_param_bound(text.as_str()));
139    }
140    let source_path = generic_path(&ast.ident, &source_generics);
141
142    // The whole thing is parameterized by 'a0, 'b0, S0, S1, T0, T1, T.
143    let all_generics = Generics {
144        lifetimes: source_generics.lifetimes.iter().cloned()
145            .chain(target_generics.lifetimes.iter().cloned())
146            .collect::<Vec<_>>(),
147        ty_params: source_generics.ty_params.iter().cloned()
148            .chain(target_generics.ty_params.iter().cloned())
149            .chain(::std::iter::once(target_ty_param))
150            .collect::<Vec<_>>(),
151        where_clause: WhereClause {
152            predicates: source_generics.where_clause.predicates.iter().cloned()
153                .chain(target_generics.where_clause.predicates.iter().cloned())
154                .collect::<Vec<_>>(),
155        },
156    };
157    let all_where_clause = all_generics.where_clause.clone();
158
159    // The recursive implementation of round_trip()
160
161    let round_trip = match ast.body {
162        syn::Body::Struct(syn::VariantData::Struct(ref body)) => {
163            let fields = body.iter()
164                .filter_map(|field| field.ident.as_ref())
165                .map(|ident| quote! { #ident: self.#ident.round_trip() })
166                .collect::<Vec<_>>();
167            quote! { #name { #(#fields),* } }
168        },
169        syn::Body::Struct(syn::VariantData::Tuple(ref body)) => {
170            let fields = (0..body.len())
171                .map(syn::Ident::from)
172                .map(|index| quote! { self.#index.round_trip() })
173                .collect::<Vec<_>>();
174            quote! { #name ( #(#fields),* ) }
175        },
176        syn::Body::Struct(syn::VariantData::Unit) => {
177            quote! { #name }
178        },
179        syn::Body::Enum(ref body) => {
180            let cases = body.iter()
181                .map(|case| {
182                    let unqualified_ident = &case.ident;
183                    let ident = quote! { #name::#unqualified_ident };
184                    match case.data {
185                        syn::VariantData::Struct(ref body) => {
186                            let idents = body.iter()
187                                .filter_map(|field| field.ident.as_ref())
188                                .collect::<Vec<_>>();;
189                            let cloned = idents.iter()
190                                .map(|ident| quote! { #ident: #ident.round_trip() })
191                                .collect::<Vec<_>>();
192                            quote! { #ident { #(ref #idents),* } => #ident { #(#cloned),* } }
193                        },
194                        syn::VariantData::Tuple(ref body) => {
195                            let idents = (0..body.len())
196                                .map(|index| syn::Ident::from(format!("x{}", index)))
197                                .collect::<Vec<_>>();
198                            let cloned = idents.iter()
199                                .map(|ident| quote! { #ident.round_trip() })
200                                .collect::<Vec<_>>();
201                            quote! { #ident ( #(ref #idents),* ) => #ident ( #(#cloned),* ) }
202                        },
203                        syn::VariantData::Unit => {
204                            quote! { #ident => #ident }
205                        },
206                    }
207                })
208                .collect::<Vec<_>>();
209            quote! { match *self { #(#cases),* } }
210        },
211    };
212
213    // Implement RoundTrip and SameDeserialization
214
215    quote! {
216        impl #all_generics ::serde_roundtrip::RoundTrip<T> for #source_path
217            #all_where_clause
218        {
219            fn round_trip(&self) -> T { T::from(#round_trip) }
220        }
221        impl #target_generics ::serde_roundtrip::SameDeserialization for #target_path
222            #target_where_clause
223        {
224            type SameAs = Self;
225            fn from(data: Self) -> Self { data }
226        }
227    }
228}