1extern crate proc_macro;
16
17use quote::{quote, ToTokens};
18use std::borrow::Cow;
19use synstructure::{decl_derive, Structure, VariantInfo};
20
/// Early-returns from the enclosing function with the `compile_error!` token
/// stream produced by the given `syn::Error` expression.
///
/// The enclosing function must return `proc_macro2::TokenStream`. The expansion
/// is a bare `return` expression (no trailing semicolon) so the macro can also be
/// used in expression position, e.g. as a `match` arm.
macro_rules! syn_throw {
    ($err:expr) => {
        return $err.to_compile_error()
    };
}
26
/// Unwraps a `syn::Result`, or early-returns its error as `compile_error!`
/// tokens via `syn_throw!` (so the same return-type requirement applies).
macro_rules! syn_try {
    ($expr:expr) => {
        match $expr {
            Err(e) => syn_throw!(e),
            Ok(value) => value,
        }
    };
}
35
36fn struct_discriminant<'v>(v: &VariantInfo<'v>) -> syn::Result<Option<Cow<'v, syn::Expr>>> {
37 v.ast()
38 .attrs
39 .iter()
40 .filter_map(|attr| match attr {
41 syn::Attribute {
42 style: syn::AttrStyle::Outer,
43 meta,
44 ..
45 } if meta.path().is_ident("wasmbin") => {
46 syn::custom_keyword!(discriminant);
47
48 Some(
49 attr.parse_args_with(|parser: syn::parse::ParseStream| {
50 parser.parse::<discriminant>()?;
51 parser.parse::<syn::Token![=]>()?;
52 parser.parse()
53 })
54 .map(Cow::Owned),
55 )
56 }
57 _ => None,
58 })
59 .try_fold(None, |prev, discriminant| {
60 let discriminant = discriminant?;
61 if let Some(prev) = prev {
62 let mut err = syn::Error::new_spanned(
63 discriminant,
64 "#[derive(Wasmbin)]: duplicate discriminant",
65 );
66 err.combine(syn::Error::new_spanned(
67 prev,
68 "#[derive(Wasmbin)]: previous discriminant here",
69 ));
70 return Err(err);
71 }
72 Ok(Some(discriminant))
73 })
74}
75
76fn gen_encode_discriminant(repr: &syn::Type, discriminant: &syn::Expr) -> proc_macro2::TokenStream {
77 quote!(<#repr as Encode>::encode(&#discriminant, w)?)
78}
79
80fn is_newtype_like(v: &VariantInfo) -> bool {
81 matches!(v.ast().fields, fields @ syn::Fields::Unnamed(_) if fields.len() == 1)
82}
83
84fn track_err_in_field(
85 mut res: proc_macro2::TokenStream,
86 v: &VariantInfo,
87 field: &syn::Field,
88 index: usize,
89) -> proc_macro2::TokenStream {
90 if !is_newtype_like(v) {
91 let field_name = match &field.ident {
92 Some(ident) => ident.to_string(),
93 None => index.to_string(),
94 };
95 res = quote!(#res.map_err(|err| err.in_path(PathItem::Name(#field_name))));
96 }
97 res
98}
99
100fn track_err_in_variant(
101 res: proc_macro2::TokenStream,
102 v: &VariantInfo,
103) -> proc_macro2::TokenStream {
104 use std::fmt::Write;
105
106 let mut variant_name = String::new();
107 if let Some(prefix) = v.prefix {
108 write!(variant_name, "{}::", prefix).unwrap();
109 }
110 write!(variant_name, "{}", v.ast().ident).unwrap();
111
112 quote!(#res.map_err(|err| err.in_path(PathItem::Variant(#variant_name))))
113}
114
115fn catch_expr(
116 res: proc_macro2::TokenStream,
117 err: proc_macro2::TokenStream,
118) -> proc_macro2::TokenStream {
119 quote!(
120 (move || -> Result<_, #err> {
121 Ok({ #res })
122 })()
123 )
124}
125
126fn gen_decode(v: &VariantInfo) -> proc_macro2::TokenStream {
127 let mut res = v.construct(|field, index| {
128 let res = track_err_in_field(quote!(Decode::decode(r)), v, field, index);
129 quote!(#res?)
130 });
131 res = catch_expr(res, quote!(DecodeError));
132 res = track_err_in_variant(res, v);
133 res
134}
135
136fn parse_repr(s: &Structure) -> syn::Result<syn::Type> {
137 s.ast()
138 .attrs
139 .iter()
140 .find(|attr| attr.path().is_ident("repr"))
141 .ok_or_else(|| {
142 syn::Error::new_spanned(
143 &s.ast().ident,
144 "Wasmbin enums must have a #[repr(type)] attribute",
145 )
146 })?
147 .parse_args()
148}
149
150fn wasmbin_derive(s: Structure) -> proc_macro2::TokenStream {
151 let (encode_discriminant, decode) = match s.ast().data {
152 syn::Data::Enum(_) => {
153 let repr = syn_try!(parse_repr(&s));
154
155 let mut encode_discriminant = quote!();
156
157 let mut decoders = quote!();
158 let mut decode_other = quote!({ return Ok(None) });
159
160 for v in s.variants() {
161 match v.ast().discriminant {
162 Some((_, discriminant)) => {
163 let pat = v.pat();
164
165 let encode = gen_encode_discriminant(&repr, discriminant);
166 (quote!(#pat => #encode,)).to_tokens(&mut encode_discriminant);
167
168 let decode = gen_decode(v);
169 (quote!(
170 #discriminant => #decode?,
171 ))
172 .to_tokens(&mut decoders);
173 }
174 None => {
175 let fields = v.ast().fields;
176 if fields.len() != 1 {
177 syn_throw!(syn::Error::new_spanned(
178 fields,
179 "Catch-all variants without discriminant must have a single field."
180 ));
181 }
182 let field = fields.iter().next().unwrap();
183 let construct = match &field.ident {
184 Some(ident) => quote!({ #ident: res }),
185 None => quote!((res)),
186 };
187 let variant_name = v.ast().ident;
188 decode_other = quote! {
189 if let Some(res) = DecodeWithDiscriminant::maybe_decode_with_discriminant(discriminant, r)? {
190 Self::#variant_name #construct
191 } else #decode_other
192 };
193 }
194 }
195 }
196
197 (
198 quote! {
199 match *self {
200 #encode_discriminant
201 _ => {}
202 }
203 },
204 quote! {
205 gen impl DecodeWithDiscriminant for @Self {
206 type Discriminant = #repr;
207
208 fn maybe_decode_with_discriminant(discriminant: #repr, r: &mut impl std::io::Read) -> Result<Option<Self>, DecodeError> {
209 Ok(Some(match discriminant {
210 #decoders
211 _ => #decode_other
212 }))
213 }
214 }
215
216 gen impl Decode for @Self {
217 fn decode(r: &mut impl std::io::Read) -> Result<Self, DecodeError> {
218 DecodeWithDiscriminant::decode_without_discriminant(r)
219 }
220 }
221 },
222 )
223 }
224 _ => {
225 let variants = s.variants();
226 assert_eq!(variants.len(), 1);
227 let v = &variants[0];
228 let decode = gen_decode(v);
229 match syn_try!(struct_discriminant(v)) {
230 Some(discriminant) => (
231 gen_encode_discriminant(&syn::parse_quote!(u8), &discriminant),
232 quote! {
233 gen impl DecodeWithDiscriminant for @Self {
234 type Discriminant = u8;
235
236 fn maybe_decode_with_discriminant(discriminant: u8, r: &mut impl std::io::Read) -> Result<Option<Self>, DecodeError> {
237 match discriminant {
238 #discriminant => #decode.map(Some),
239 _ => Ok(None),
240 }
241 }
242 }
243
244 gen impl Decode for @Self {
245 fn decode(r: &mut impl std::io::Read) -> Result<Self, DecodeError> {
246 DecodeWithDiscriminant::decode_without_discriminant(r)
247 }
248 }
249 },
250 ),
251 None => (
252 quote! {},
253 quote! {
254 gen impl Decode for @Self {
255 fn decode(r: &mut impl std::io::Read) -> Result<Self, DecodeError> {
256 #decode
257 }
258 }
259 },
260 ),
261 }
262 }
263 };
264
265 let encode_body = s.each(|bi| {
266 quote! {
267 Encode::encode(#bi, w)?
268 }
269 });
270
271 s.gen_impl(quote! {
272 use crate::io::{Encode, Decode, DecodeWithDiscriminant, DecodeError, PathItem};
273
274 gen impl Encode for @Self {
275 fn encode(&self, w: &mut impl std::io::Write) -> std::io::Result<()> {
276 #encode_discriminant;
277 match *self { #encode_body }
278 Ok(())
279 }
280 }
281
282 #decode
283 })
284}
285
286fn wasmbin_countable_derive(s: Structure) -> proc_macro2::TokenStream {
287 s.gen_impl(quote! {
288 gen impl crate::builtins::WasmbinCountable for @Self {}
289 })
290}
291
292fn wasmbin_visit_derive(mut s: Structure) -> proc_macro2::TokenStream {
293 s.bind_with(|_| synstructure::BindStyle::Move);
294
295 fn generate_visit_body(
296 s: &Structure,
297 method: proc_macro2::TokenStream,
298 ) -> proc_macro2::TokenStream {
299 let body = s.each_variant(|v| {
300 let res = v.bindings().iter().enumerate().map(|(i, bi)| {
301 let res = quote!(Visit::#method(#bi, f));
302 track_err_in_field(res, v, bi.ast(), i)
303 });
304 let mut res = quote!(#(#res?;)*);
305 res = catch_expr(res, quote!(VisitError<VisitE>));
306 res = track_err_in_variant(res, v);
307 quote!(#res?)
308 });
309 quote!(
310 match self { #body }
311 Ok(())
312 )
313 }
314
315 let visit_children_body = generate_visit_body(&s, quote!(visit_child));
316
317 let visit_children_mut_body = generate_visit_body(&s, quote!(visit_child_mut));
318
319 s.gen_impl(quote! {
320 use crate::visit::{Visit, VisitError};
321 use crate::io::PathItem;
322
323 gen impl Visit for @Self where Self: 'static {
324 fn visit_children<'a, VisitT: 'static, VisitE, VisitF: FnMut(&'a VisitT) -> Result<(), VisitE>>(&'a self, f: &mut VisitF) -> Result<(), VisitError<VisitE>> {
325 #visit_children_body
326 }
327
328 fn visit_children_mut<VisitT: 'static, VisitE, VisitF: FnMut(&mut VisitT) -> Result<(), VisitE>>(&mut self, f: &mut VisitF) -> Result<(), VisitError<VisitE>> {
329 #visit_children_mut_body
330 }
331 }
332 })
333}
334
335decl_derive!([Wasmbin, attributes(wasmbin)] => wasmbin_derive);
336decl_derive!([WasmbinCountable] => wasmbin_countable_derive);
337decl_derive!([Visit] => wasmbin_visit_derive);