visit_diff_derive/
lib.rs

1//! Derives the `Diff` trait naively, using the literal structure of the
2//! datatype.
3
4extern crate proc_macro;
5
6use proc_macro::TokenStream;
7use quote::{quote, quote_spanned};
8use std::iter::FromIterator;
9use syn;
10use syn::spanned::Spanned;
11
12#[proc_macro_derive(Diff)]
13pub fn diff_derive(input: TokenStream) -> TokenStream {
14    let input = syn::parse_macro_input!(input as syn::DeriveInput);
15
16    let name = input.ident;
17
18    let generics = add_trait_bounds(input.generics);
19    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
20
21    let dispatch = gen_dispatch(&name, &input.data);
22
23    let expanded = quote_spanned! {name.span()=>
24        impl #impl_generics ::visit_diff::Diff for #name #ty_generics
25        #where_clause {
26            fn diff<D>(a: &Self, b: &Self, out: D)
27                -> ::std::result::Result<D::Ok, D::Err>
28            where D: ::visit_diff::Differ
29            {
30                #dispatch
31            }
32        }
33    };
34
35    TokenStream::from(expanded)
36}
37
38/// Naively slaps a `Diff` bound on every generic type parameter. This leads to
39/// overconstrained impls but it's sure easy -- and it's essentially what the
40/// built in derives do.
41fn add_trait_bounds(mut generics: syn::Generics) -> syn::Generics {
42    for param in &mut generics.params {
43        if let syn::GenericParam::Type(type_param) = param {
44            type_param
45                .bounds
46                .push(syn::parse_quote!(::visit_diff::Diff));
47        }
48    }
49    generics
50}
51
52/// Generates the "dispatcher" body of `diff`, which turns around and calls
53/// methods on the `Differ` depending on type.
54fn gen_dispatch(ty: &syn::Ident, data: &syn::Data) -> proc_macro2::TokenStream {
55    match data {
56        syn::Data::Struct(data) => {
57            match &data.fields {
58                syn::Fields::Named(fields) => gen_named_struct(ty, fields),
59                syn::Fields::Unnamed(fields) => gen_unnamed_struct(ty, fields),
60                syn::Fields::Unit => {
61                    // A unit struct without fields. There is only one instance
62                    // of such a type, and so we know statically that our
63                    // arguments are the same.
64                    quote_spanned! {ty.span()=>
65                        out.same(&a, &b)
66                    }
67                }
68            }
69        }
70        syn::Data::Enum(data) => {
71            // Enums are more complex than structs, because each variant can
72            // have a different shape. We'll process the variants and generate
73            // the corresponding match arms.
74            let variants = data.variants.iter().map(|v| {
75                let name = &v.ident;
76                match &v.fields {
77                    syn::Fields::Named(fields) => {
78                        gen_named_variant(ty, name, fields)
79                    }
80                    syn::Fields::Unnamed(fields) => {
81                        gen_unnamed_variant(ty, name, fields)
82                    }
83                    syn::Fields::Unit => {
84                        // For a unit variant, we only need to check that both
85                        // sides use the same variant.
86                        quote_spanned! {v.span()=>
87                            (#ty::#name, #ty::#name) => out.same(a, b),
88                        }
89                    }
90                }
91            });
92            let variants = proc_macro2::TokenStream::from_iter(variants);
93
94            // Now combine the match arms into a valid match expression.
95            quote_spanned! {ty.span()=>
96                match (a, b) {
97                    #variants
98                    _ => out.difference(a, b),
99                }
100            }
101        }
102        syn::Data::Union(_) => {
103            unimplemented!("A `union` type cannot be meaningfully diffed")
104        }
105    }
106}
107
108/// Generates dispatcher for a named struct.
109///
110/// Named structs are different from enum variants with named fields, because of
111/// the different ways we access their fields.
112fn gen_named_struct(
113    ty: &syn::Ident,
114    fields: &syn::FieldsNamed,
115) -> proc_macro2::TokenStream {
116    // A traditional struct: named fields, curly braces, etc.
117    // Generated code will resemble:
118    //
119    //   let mut s = out.begin_struct("TypeName");
120    //   s.diff_field("field1", &a.field1, &b.field1);
121    //   s.diff_field("field2", &a.field2, &b.field2);
122    //   s.end()
123
124    // First, generate the `diff_field` statements.
125    let stmts = fields.named.iter().map(|f| {
126        let name = &f.ident;
127        quote_spanned! {f.span()=>
128            s.diff_field(stringify!(#name), &a.#name, &b.#name);
129        }
130    });
131    let stmts = proc_macro2::TokenStream::from_iter(stmts);
132
133    quote_spanned! {ty.span()=>
134        use ::visit_diff::StructDiffer;
135        let mut s = out.begin_struct(stringify!(#ty));
136        #stmts
137        s.end()
138    }
139}
140
141/// Generates dispatcher for a named enum variant.
142///
143/// Named structs are different from enum variants with named fields, because of
144/// the different ways we access their fields.
145fn gen_named_variant(
146    ty: &syn::Ident,
147    name: &syn::Ident,
148    fields: &syn::FieldsNamed,
149) -> proc_macro2::TokenStream {
150    // A variant with named fields is very much like a
151    // struct, except that we have to access the fields
152    // using pattern matching instead of dotted names.
153    //
154    // Generated match arm will resemble:
155    //
156    //   ( Ty::Var { f: f_a, v: v_a },
157    //     Ty::Var { f: f_b, v: v_b } ) => {
158    //       use ::visit_diff::StructDiffer;
159    //       let mut s = out.begin_struct_variant("Ty", "Var");
160    //       s.diff_field("f", f_a, f_b);
161    //       s.diff_field("v", v_a, v_b);
162    //       s.end()
163    //   },
164    let a_pat = named_fields_pattern(fields.named.iter(), "_a");
165    let b_pat = named_fields_pattern(fields.named.iter(), "_b");
166    let stmts = diff_named_fields(fields.named.iter(), "_a", "_b");
167    quote_spanned! {name.span()=>
168        ( #ty::#name { #a_pat },
169          #ty::#name { #b_pat }) => {
170            use ::visit_diff::StructDiffer;
171            let mut s = out.begin_struct_variant(
172                stringify!(#ty),
173                stringify!(#name),
174            );
175            #stmts
176            s.end()
177        },
178    }
179}
180
181/// Generates dispatcher for a struct with unnamed fields (i.e. a tuple struct).
182fn gen_unnamed_struct(
183    ty: &syn::Ident,
184    fields: &syn::FieldsUnnamed,
185) -> proc_macro2::TokenStream {
186    // A tuple struct: unnamed fields, parens. Generated code
187    // will resemble:
188    //
189    //   let mut s = out.begin_tuple("TypeName");
190    //   s.diff_field(&a.0, &b.0);
191    //   s.diff_field(&a.1, &b.1);
192    //   s.end()
193
194    // First, generate the `diff_field` statements.
195    let stmts = fields.unnamed.iter().enumerate().map(|(i, f)| {
196        let index = syn::Index::from(i);
197        quote_spanned! {f.span()=>
198            s.diff_field(&a.#index, &b.#index);
199        }
200    });
201    let stmts = proc_macro2::TokenStream::from_iter(stmts);
202    quote_spanned! {ty.span()=>
203        use ::visit_diff::TupleDiffer;
204        let mut s = out.begin_tuple(stringify!(#ty));
205        #stmts
206        s.end()
207    }
208}
209
210/// Generates dispatcher for an enum variant with unnamed fields (i.e. a tuple
211/// variant).
212fn gen_unnamed_variant(
213    ty: &syn::Ident,
214    name: &syn::Ident,
215    fields: &syn::FieldsUnnamed,
216) -> proc_macro2::TokenStream {
217    // A variant with unnamed fields is very much like a tuple struct, except
218    // that we have to access the fields by pattern matching instead of using
219    // dotted numbers.
220    //
221    // Generated match arm will resemble:
222    //   ( Ty::Var(a0, a1),
223    //     Ty::Var(b0, b1) ) => {
224    //       use ::visit_diff::TupletDiffer;
225    //       let mut s = out.begin_tuple("Ty");
226    //       s.diff_field(f_a, f_b);
227    //       s.diff_field(v_a, v_b);
228    //       s.end()
229    //   },
230    let a_pat = unnamed_fields_pattern(fields.unnamed.iter(), "a");
231    let b_pat = unnamed_fields_pattern(fields.unnamed.iter(), "b");
232    let stmts = diff_unnamed_fields(fields.unnamed.iter(), "a", "b");
233    quote_spanned! {name.span()=>
234        (#ty::#name(#a_pat), #ty::#name(#b_pat)) => {
235            use ::visit_diff::TupleDiffer;
236            let mut s = out.begin_tuple_variant(
237                stringify!(#ty),
238                stringify!(#name),
239            );
240            #stmts
241            s.end()
242        },
243    }
244}
245
246/// Generates a pattern match that captures named fields under new names. This
247/// is used to capture the values of fields in a named-field enum variant.
248///
249/// Because we are matching *two copies* of the variant, we can't use the simple
250/// struct field match syntax, as it would try to bind each name twice:
251///
252/// ```ignore
253/// (Variant { a, b, c }, Variant { a, b, c })
254/// ```
255///
256/// Instead, we use this function to generate newly named bindings for each
257/// field, suffixed by `suffix` -- which is different for the left and right
258/// side. So, if we used the suffix `_left` on one side and `_right` on the
259/// other, we'd get the following:
260///
261/// ```ignore
262/// (Variant { a: a_left, b: b_left, c: c_left },
263///  Variant { a: a_right, b: b_right, c: c_right })
264/// ```
265///
266/// (This function is only responsible for the portion *within* the curly braces
267/// above.)
268fn named_fields_pattern<'a, I>(
269    fields: I,
270    suffix: &str,
271) -> proc_macro2::TokenStream
272where
273    I: IntoIterator<Item = &'a syn::Field>,
274{
275    let pat = fields.into_iter().map(|f| {
276        let name = f.ident.as_ref().unwrap();
277        let suffixed =
278            syn::Ident::new(&format!("{}{}", name, suffix), name.span());
279        quote_spanned! {f.span()=> #name: #suffixed, }
280    });
281    proc_macro2::TokenStream::from_iter(pat)
282}
283
284/// Generates a pattern match that gives names to unnamed fields. This is used
285/// to capture the values of fields in a tuple-style enum variant.
286///
287/// We simply append a number to the given `prefix` for each field. So, if we
288/// used the prefix `left_` on one side and `right_` on the other, a three-field
289/// tuple enum variant would produce the following match pattern:
290///
291/// ```ignore
292/// (Variant(left_0, left_1, left_2), Variant(right_0, right_1, right_2))
293/// ```
294///
295/// (This function is only responsible for the portion *within* the inner
296/// parentheses above.)
297fn unnamed_fields_pattern<'a, I>(
298    fields: I,
299    prefix: &str,
300) -> proc_macro2::TokenStream
301where
302    I: IntoIterator<Item = &'a syn::Field>,
303{
304    let pat = fields.into_iter().enumerate().map(|(i, f)| {
305        let name = syn::Ident::new(&format!("{}{}", prefix, i), f.span());
306        quote_spanned! {f.span()=> #name, }
307    });
308    proc_macro2::TokenStream::from_iter(pat)
309}
310
311/// Given named fields bound by `named_fields_pattern`, generates code to apply
312/// the `StructDiffer` to each pair.
313fn diff_named_fields<'a, I>(
314    fields: I,
315    left_suffix: &str,
316    right_suffix: &str,
317) -> proc_macro2::TokenStream
318where
319    I: IntoIterator<Item = &'a syn::Field>,
320{
321    let stmts = fields.into_iter().map(|f| {
322        let name = f.ident.as_ref().unwrap();
323        let left =
324            syn::Ident::new(&format!("{}{}", name, left_suffix), name.span());
325        let right =
326            syn::Ident::new(&format!("{}{}", name, right_suffix), name.span());
327        quote_spanned! {f.span()=>
328            s.diff_field(stringify!(#name), #left, #right);
329        }
330    });
331    proc_macro2::TokenStream::from_iter(stmts)
332}
333
334/// Given unnamed fields bound by `unnamed_fields_pattern`, generates code to
335/// apply the `TupleDiffer` to each pair.
336fn diff_unnamed_fields<'a, I>(
337    fields: I,
338    left_prefix: &str,
339    right_prefix: &str,
340) -> proc_macro2::TokenStream
341where
342    I: IntoIterator<Item = &'a syn::Field>,
343{
344    let stmts = fields.into_iter().enumerate().map(|(i, f)| {
345        let left = syn::Ident::new(&format!("{}{}", left_prefix, i), f.span());
346        let right =
347            syn::Ident::new(&format!("{}{}", right_prefix, i), f.span());
348        quote_spanned! {f.span()=>
349            s.diff_field(#left, #right);
350        }
351    });
352    proc_macro2::TokenStream::from_iter(stmts)
353}