wasmbin_derive/
lib.rs

// Copyright 2020 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15extern crate proc_macro;
16
17use quote::{quote, ToTokens};
18use std::borrow::Cow;
19use synstructure::{decl_derive, Structure, VariantInfo};
20
/// Returns early from the enclosing function, yielding the given `syn::Error`
/// rendered as a compile-error token stream.
macro_rules! syn_throw {
    ($err:expr) => {
        return $err.to_compile_error()
    };
}
26
/// Unwraps a `syn::Result`, or returns early from the enclosing function with
/// the error rendered as a compile-error token stream (see `syn_throw!`).
macro_rules! syn_try {
    ($expr:expr) => {
        match $expr {
            Err(error) => syn_throw!(error),
            Ok(value) => value,
        }
    };
}
35
36fn struct_discriminant<'v>(v: &VariantInfo<'v>) -> syn::Result<Option<Cow<'v, syn::Expr>>> {
37    v.ast()
38        .attrs
39        .iter()
40        .filter_map(|attr| match attr {
41            syn::Attribute {
42                style: syn::AttrStyle::Outer,
43                meta,
44                ..
45            } if meta.path().is_ident("wasmbin") => {
46                syn::custom_keyword!(discriminant);
47
48                Some(
49                    attr.parse_args_with(|parser: syn::parse::ParseStream| {
50                        parser.parse::<discriminant>()?;
51                        parser.parse::<syn::Token![=]>()?;
52                        parser.parse()
53                    })
54                    .map(Cow::Owned),
55                )
56            }
57            _ => None,
58        })
59        .try_fold(None, |prev, discriminant| {
60            let discriminant = discriminant?;
61            if let Some(prev) = prev {
62                let mut err = syn::Error::new_spanned(
63                    discriminant,
64                    "#[derive(Wasmbin)]: duplicate discriminant",
65                );
66                err.combine(syn::Error::new_spanned(
67                    prev,
68                    "#[derive(Wasmbin)]: previous discriminant here",
69                ));
70                return Err(err);
71            }
72            Ok(Some(discriminant))
73        })
74}
75
76fn gen_encode_discriminant(repr: &syn::Type, discriminant: &syn::Expr) -> proc_macro2::TokenStream {
77    quote!(<#repr as Encode>::encode(&#discriminant, w)?)
78}
79
80fn is_newtype_like(v: &VariantInfo) -> bool {
81    matches!(v.ast().fields, fields @ syn::Fields::Unnamed(_) if fields.len() == 1)
82}
83
84fn track_err_in_field(
85    mut res: proc_macro2::TokenStream,
86    v: &VariantInfo,
87    field: &syn::Field,
88    index: usize,
89) -> proc_macro2::TokenStream {
90    if !is_newtype_like(v) {
91        let field_name = match &field.ident {
92            Some(ident) => ident.to_string(),
93            None => index.to_string(),
94        };
95        res = quote!(#res.map_err(|err| err.in_path(PathItem::Name(#field_name))));
96    }
97    res
98}
99
100fn track_err_in_variant(
101    res: proc_macro2::TokenStream,
102    v: &VariantInfo,
103) -> proc_macro2::TokenStream {
104    use std::fmt::Write;
105
106    let mut variant_name = String::new();
107    if let Some(prefix) = v.prefix {
108        write!(variant_name, "{}::", prefix).unwrap();
109    }
110    write!(variant_name, "{}", v.ast().ident).unwrap();
111
112    quote!(#res.map_err(|err| err.in_path(PathItem::Variant(#variant_name))))
113}
114
115fn catch_expr(
116    res: proc_macro2::TokenStream,
117    err: proc_macro2::TokenStream,
118) -> proc_macro2::TokenStream {
119    quote!(
120        (move || -> Result<_, #err> {
121            Ok({ #res })
122        })()
123    )
124}
125
126fn gen_decode(v: &VariantInfo) -> proc_macro2::TokenStream {
127    let mut res = v.construct(|field, index| {
128        let res = track_err_in_field(quote!(Decode::decode(r)), v, field, index);
129        quote!(#res?)
130    });
131    res = catch_expr(res, quote!(DecodeError));
132    res = track_err_in_variant(res, v);
133    res
134}
135
136fn parse_repr(s: &Structure) -> syn::Result<syn::Type> {
137    s.ast()
138        .attrs
139        .iter()
140        .find(|attr| attr.path().is_ident("repr"))
141        .ok_or_else(|| {
142            syn::Error::new_spanned(
143                &s.ast().ident,
144                "Wasmbin enums must have a #[repr(type)] attribute",
145            )
146        })?
147        .parse_args()
148}
149
150fn wasmbin_derive(s: Structure) -> proc_macro2::TokenStream {
151    let (encode_discriminant, decode) = match s.ast().data {
152        syn::Data::Enum(_) => {
153            let repr = syn_try!(parse_repr(&s));
154
155            let mut encode_discriminant = quote!();
156
157            let mut decoders = quote!();
158            let mut decode_other = quote!({ return Ok(None) });
159
160            for v in s.variants() {
161                match v.ast().discriminant {
162                    Some((_, discriminant)) => {
163                        let pat = v.pat();
164
165                        let encode = gen_encode_discriminant(&repr, discriminant);
166                        (quote!(#pat => #encode,)).to_tokens(&mut encode_discriminant);
167
168                        let decode = gen_decode(v);
169                        (quote!(
170                            #discriminant => #decode?,
171                        ))
172                        .to_tokens(&mut decoders);
173                    }
174                    None => {
175                        let fields = v.ast().fields;
176                        if fields.len() != 1 {
177                            syn_throw!(syn::Error::new_spanned(
178                                fields,
179                                "Catch-all variants without discriminant must have a single field."
180                            ));
181                        }
182                        let field = fields.iter().next().unwrap();
183                        let construct = match &field.ident {
184                            Some(ident) => quote!({ #ident: res }),
185                            None => quote!((res)),
186                        };
187                        let variant_name = v.ast().ident;
188                        decode_other = quote! {
189                            if let Some(res) = DecodeWithDiscriminant::maybe_decode_with_discriminant(discriminant, r)? {
190                                Self::#variant_name #construct
191                            } else #decode_other
192                        };
193                    }
194                }
195            }
196
197            (
198                quote! {
199                    match *self {
200                        #encode_discriminant
201                        _ => {}
202                    }
203                },
204                quote! {
205                    gen impl DecodeWithDiscriminant for @Self {
206                        type Discriminant = #repr;
207
208                        fn maybe_decode_with_discriminant(discriminant: #repr, r: &mut impl std::io::Read) -> Result<Option<Self>, DecodeError> {
209                            Ok(Some(match discriminant {
210                                #decoders
211                                _ => #decode_other
212                            }))
213                        }
214                    }
215
216                    gen impl Decode for @Self {
217                        fn decode(r: &mut impl std::io::Read) -> Result<Self, DecodeError> {
218                            DecodeWithDiscriminant::decode_without_discriminant(r)
219                        }
220                    }
221                },
222            )
223        }
224        _ => {
225            let variants = s.variants();
226            assert_eq!(variants.len(), 1);
227            let v = &variants[0];
228            let decode = gen_decode(v);
229            match syn_try!(struct_discriminant(v)) {
230                Some(discriminant) => (
231                    gen_encode_discriminant(&syn::parse_quote!(u8), &discriminant),
232                    quote! {
233                        gen impl DecodeWithDiscriminant for @Self {
234                            type Discriminant = u8;
235
236                            fn maybe_decode_with_discriminant(discriminant: u8, r: &mut impl std::io::Read) -> Result<Option<Self>, DecodeError> {
237                                match discriminant {
238                                    #discriminant => #decode.map(Some),
239                                    _ => Ok(None),
240                                }
241                            }
242                        }
243
244                        gen impl Decode for @Self {
245                            fn decode(r: &mut impl std::io::Read) -> Result<Self, DecodeError> {
246                                DecodeWithDiscriminant::decode_without_discriminant(r)
247                            }
248                        }
249                    },
250                ),
251                None => (
252                    quote! {},
253                    quote! {
254                        gen impl Decode for @Self {
255                            fn decode(r: &mut impl std::io::Read) -> Result<Self, DecodeError> {
256                                #decode
257                            }
258                        }
259                    },
260                ),
261            }
262        }
263    };
264
265    let encode_body = s.each(|bi| {
266        quote! {
267            Encode::encode(#bi, w)?
268        }
269    });
270
271    s.gen_impl(quote! {
272        use crate::io::{Encode, Decode, DecodeWithDiscriminant, DecodeError, PathItem};
273
274        gen impl Encode for @Self {
275            fn encode(&self, w: &mut impl std::io::Write) -> std::io::Result<()> {
276                #encode_discriminant;
277                match *self { #encode_body }
278                Ok(())
279            }
280        }
281
282        #decode
283    })
284}
285
286fn wasmbin_countable_derive(s: Structure) -> proc_macro2::TokenStream {
287    s.gen_impl(quote! {
288        gen impl crate::builtins::WasmbinCountable for @Self {}
289    })
290}
291
292fn wasmbin_visit_derive(mut s: Structure) -> proc_macro2::TokenStream {
293    s.bind_with(|_| synstructure::BindStyle::Move);
294
295    fn generate_visit_body(
296        s: &Structure,
297        method: proc_macro2::TokenStream,
298    ) -> proc_macro2::TokenStream {
299        let body = s.each_variant(|v| {
300            let res = v.bindings().iter().enumerate().map(|(i, bi)| {
301                let res = quote!(Visit::#method(#bi, f));
302                track_err_in_field(res, v, bi.ast(), i)
303            });
304            let mut res = quote!(#(#res?;)*);
305            res = catch_expr(res, quote!(VisitError<VisitE>));
306            res = track_err_in_variant(res, v);
307            quote!(#res?)
308        });
309        quote!(
310            match self { #body }
311            Ok(())
312        )
313    }
314
315    let visit_children_body = generate_visit_body(&s, quote!(visit_child));
316
317    let visit_children_mut_body = generate_visit_body(&s, quote!(visit_child_mut));
318
319    s.gen_impl(quote! {
320        use crate::visit::{Visit, VisitError};
321        use crate::io::PathItem;
322
323        gen impl Visit for @Self where Self: 'static {
324            fn visit_children<'a, VisitT: 'static, VisitE, VisitF: FnMut(&'a VisitT) -> Result<(), VisitE>>(&'a self, f: &mut VisitF) -> Result<(), VisitError<VisitE>> {
325                #visit_children_body
326            }
327
328            fn visit_children_mut<VisitT: 'static, VisitE, VisitF: FnMut(&mut VisitT) -> Result<(), VisitE>>(&mut self, f: &mut VisitF) -> Result<(), VisitError<VisitE>> {
329                #visit_children_mut_body
330            }
331        }
332    })
333}
334
335decl_derive!([Wasmbin, attributes(wasmbin)] => wasmbin_derive);
336decl_derive!([WasmbinCountable] => wasmbin_countable_derive);
337decl_derive!([Visit] => wasmbin_visit_derive);