Skip to main content

wasm_bindgen_test_macro/
lib.rs

1//! See the README for `wasm-bindgen-test` for a bit more info about what's
2//! going on here.
3
4extern crate proc_macro;
5
6use proc_macro2::*;
7use quote::quote;
8use quote::quote_spanned;
9
10#[proc_macro_attribute]
11pub fn wasm_bindgen_bench(
12    attr: proc_macro::TokenStream,
13    body: proc_macro::TokenStream,
14) -> proc_macro::TokenStream {
15    bindgen(attr, body, true)
16}
17
18#[proc_macro_attribute]
19pub fn wasm_bindgen_test(
20    attr: proc_macro::TokenStream,
21    body: proc_macro::TokenStream,
22) -> proc_macro::TokenStream {
23    bindgen(attr, body, false)
24}
25
26fn bindgen(
27    attr: proc_macro::TokenStream,
28    body: proc_macro::TokenStream,
29    is_bench: bool,
30) -> proc_macro::TokenStream {
31    let mut attributes = Attributes::default();
32    let attribute_parser = syn::meta::parser(|meta| attributes.parse(meta));
33
34    syn::parse_macro_input!(attr with attribute_parser);
35    let mut should_panic = None;
36    let mut ignore = None;
37
38    let mut body = TokenStream::from(body).into_iter().peekable();
39
40    // Skip over other attributes to `fn #ident ...`, and extract `#ident`
41    let mut leading_tokens = Vec::new();
42    while let Some(token) = body.next() {
43        match parse_should_panic(&mut body, &token) {
44            Ok(Some((new_should_panic, span))) => {
45                if should_panic.replace(new_should_panic).is_some() {
46                    return compile_error(span, "duplicate `should_panic` attribute");
47                }
48
49                // If we found a `should_panic`, we should skip the `#` and `[...]`.
50                // The `[...]` is skipped here, the current `#` is skipped by using `continue`.
51                body.next();
52                continue;
53            }
54            Ok(None) => (),
55            Err(error) => return error,
56        }
57
58        match parse_ignore(&mut body, &token) {
59            Ok(Some((new_ignore, span))) => {
60                if ignore.replace(new_ignore).is_some() {
61                    return compile_error(span, "duplicate `ignore` attribute");
62                }
63
64                // If we found a `new_ignore`, we should skip the `#` and `[...]`.
65                // The `[...]` is skipped here, the current `#` is skipped by using `continue`.
66                body.next();
67                continue;
68            }
69            Ok(None) => (),
70            Err(error) => return error,
71        }
72
73        leading_tokens.push(token.clone());
74        if let TokenTree::Ident(token) = token {
75            if token == "async" {
76                attributes.r#async = true;
77            }
78            if token == "fn" {
79                break;
80            }
81        }
82    }
83    let ident = find_ident(&mut body).expect("expected a function name");
84
85    let mut tokens = Vec::<TokenTree>::new();
86
87    let should_panic_par = match &should_panic {
88        Some(Some(lit)) => {
89            quote! { ::core::option::Option::Some(::core::option::Option::Some(#lit)) }
90        }
91        Some(None) => quote! { ::core::option::Option::Some(::core::option::Option::None) },
92        None => quote! { ::core::option::Option::None },
93    };
94
95    let ignore_par = match &ignore {
96        Some(Some(lit)) => {
97            quote! { ::core::option::Option::Some(::core::option::Option::Some(#lit)) }
98        }
99        Some(None) => quote! { ::core::option::Option::Some(::core::option::Option::None) },
100        None => quote! { ::core::option::Option::None },
101    };
102
103    let exec_ident = if is_bench {
104        let body = if attributes.r#async {
105            quote! { #ident(&mut bencher).await; }
106        } else {
107            quote! { #ident(&mut bencher); }
108        };
109        let bench_ident = quote::format_ident!("__wbg_bench_{ident}");
110        tokens.extend(quote! {
111            async fn #bench_ident() {
112                let mut bencher = Criterion::default()
113                    .with_location(file!(), module_path!());
114                #body
115            }
116        });
117        bench_ident
118    } else {
119        ident.clone()
120    };
121
122    let test_body = if attributes.r#async || is_bench {
123        quote! { cx.execute_async(test_name, #exec_ident, #should_panic_par, #ignore_par); }
124    } else {
125        quote! { cx.execute_sync(test_name, #exec_ident, #should_panic_par, #ignore_par); }
126    };
127
128    let ignore_name = if ignore.is_some() { "$" } else { "" };
129
130    let wasm_bindgen_path = attributes.wasm_bindgen_path;
131    let prefix = if is_bench { "__wbgb_" } else { "__wbgt_" };
132    tokens.extend(
133        quote! {
134            const _: () = {
135                #wasm_bindgen_path::__rt::wasm_bindgen::__wbindgen_coverage! {
136                #[export_name = ::core::concat!(#prefix, #ignore_name, "_", ::core::module_path!(), "::", ::core::stringify!(#ident))]
137                #[cfg(target_family = "wasm")]
138                extern "C" fn __wbgt_test(cx: &#wasm_bindgen_path::__rt::Context) {
139                    let test_name = ::core::concat!(::core::module_path!(), "::", ::core::stringify!(#ident));
140                    #test_body
141                }
142                }
143            };
144        },
145    );
146
147    if let Some(path) = attributes.unsupported {
148        tokens.extend(quote! { #[cfg_attr(not(target_family = "wasm"), #path)] });
149
150        if let Some(should_panic) = should_panic {
151            let should_panic = if let Some(lit) = should_panic {
152                quote! { should_panic = #lit }
153            } else {
154                quote! { should_panic }
155            };
156
157            tokens.extend(quote! { #[cfg_attr(not(target_family = "wasm"), #should_panic)] })
158        }
159
160        if let Some(ignore) = ignore {
161            let ignore = if let Some(lit) = ignore {
162                quote! { ignore = #lit }
163            } else {
164                quote! { ignore }
165            };
166
167            tokens.extend(quote! { #[cfg_attr(not(target_family = "wasm"), #ignore)] })
168        }
169    } else {
170        tokens.extend(quote! {
171            #[cfg_attr(not(target_family = "wasm"), allow(dead_code))]
172        });
173    }
174
175    tokens.extend(leading_tokens);
176    tokens.push(ident.into());
177    tokens.extend(body);
178
179    tokens.into_iter().collect::<TokenStream>().into()
180}
181
182fn parse_should_panic(
183    body: &mut std::iter::Peekable<token_stream::IntoIter>,
184    token: &TokenTree,
185) -> Result<Option<(Option<Literal>, Span)>, proc_macro::TokenStream> {
186    // Start by parsing the `#`
187    match token {
188        TokenTree::Punct(op) if op.as_char() == '#' => (),
189        _ => return Ok(None),
190    }
191
192    // Parse `[...]`
193    let group = match body.peek() {
194        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Bracket => group,
195        _ => return Ok(None),
196    };
197
198    let mut stream = group.stream().into_iter();
199
200    // Parse `should_panic`
201    let mut span = match stream.next() {
202        Some(TokenTree::Ident(token)) if token == "should_panic" => token.span(),
203        _ => return Ok(None),
204    };
205
206    let should_panic = span;
207
208    // We are interested in the `expected` attribute or string if there is any
209    match stream.next() {
210        // Parse the `(...)` in `#[should_panic(...)]`
211        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Parenthesis => {
212            let span = group.span();
213            stream = group.stream().into_iter();
214
215            // Parse `expected`
216            match stream.next() {
217                Some(TokenTree::Ident(token)) if token == "expected" => (),
218                _ => {
219                    return Err(compile_error(
220                        span,
221                        "malformed `#[should_panic(...)]` attribute",
222                    ))
223                }
224            }
225
226            // Parse `=`
227            match stream.next() {
228                Some(TokenTree::Punct(op)) if op.as_char() == '=' => (),
229                _ => {
230                    return Err(compile_error(
231                        span,
232                        "malformed `#[should_panic(...)]` attribute",
233                    ))
234                }
235            }
236        }
237        // Parse `=`
238        Some(TokenTree::Punct(op)) if op.as_char() == '=' => (),
239        Some(token) => {
240            return Err(compile_error(
241                token.span(),
242                "malformed `#[should_panic = \"...\"]` attribute",
243            ))
244        }
245        None => {
246            return Ok(Some((None, should_panic)));
247        }
248    }
249
250    // Parse string in `#[should_panic(expected = "string")]` or `#[should_panic = "string"]`
251    if let Some(TokenTree::Literal(lit)) = stream.next() {
252        span = lit.span();
253        let string = lit.to_string();
254
255        // Verify it's a string.
256        if string.starts_with('"') && string.ends_with('"') {
257            return Ok(Some((Some(lit), should_panic)));
258        }
259    }
260
261    Err(compile_error(span, "malformed `#[should_panic]` attribute"))
262}
263
264fn parse_ignore(
265    body: &mut std::iter::Peekable<token_stream::IntoIter>,
266    token: &TokenTree,
267) -> Result<Option<(Option<Literal>, Span)>, proc_macro::TokenStream> {
268    // Start by parsing the `#`
269    match token {
270        TokenTree::Punct(op) if op.as_char() == '#' => (),
271        _ => return Ok(None),
272    }
273
274    // Parse `[...]`
275    let group = match body.peek() {
276        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Bracket => group,
277        _ => return Ok(None),
278    };
279
280    let mut stream = group.stream().into_iter();
281
282    // Parse `ignore`
283    let mut span = match stream.next() {
284        Some(TokenTree::Ident(token)) if token == "ignore" => token.span(),
285        _ => return Ok(None),
286    };
287
288    let ignore = span;
289
290    // We are interested in the reason string if there is any
291    match stream.next() {
292        // Parse `=`
293        Some(TokenTree::Punct(op)) if op.as_char() == '=' => (),
294        Some(token) => {
295            return Err(compile_error(
296                token.span(),
297                "malformed `#[ignore = \"...\"]` attribute",
298            ))
299        }
300        None => {
301            return Ok(Some((None, ignore)));
302        }
303    }
304
305    // Parse string in `#[ignore = "string"]`
306    if let Some(TokenTree::Literal(lit)) = stream.next() {
307        span = lit.span();
308        let string = lit.to_string();
309
310        // Verify it's a string.
311        if string.starts_with('"') && string.ends_with('"') {
312            return Ok(Some((Some(lit), ignore)));
313        }
314    }
315
316    Err(compile_error(span, "malformed `#[ignore]` attribute"))
317}
318
319fn find_ident(iter: &mut impl Iterator<Item = TokenTree>) -> Option<Ident> {
320    match iter.next()? {
321        TokenTree::Ident(i) => Some(i),
322        TokenTree::Group(g) if g.delimiter() == Delimiter::None => {
323            find_ident(&mut g.stream().into_iter())
324        }
325        _ => None,
326    }
327}
328
329fn compile_error(span: Span, msg: &str) -> proc_macro::TokenStream {
330    quote_spanned! { span => compile_error!(#msg); }.into()
331}
332
333struct Attributes {
334    r#async: bool,
335    wasm_bindgen_path: syn::Path,
336    unsupported: Option<syn::Meta>,
337}
338
339impl Default for Attributes {
340    fn default() -> Self {
341        Self {
342            r#async: false,
343            wasm_bindgen_path: syn::parse_quote!(::wasm_bindgen_test),
344            unsupported: None,
345        }
346    }
347}
348
349impl Attributes {
350    fn parse(&mut self, meta: syn::meta::ParseNestedMeta) -> syn::parse::Result<()> {
351        if meta.path.is_ident("async") {
352            self.r#async = true;
353        } else if meta.path.is_ident("crate") {
354            self.wasm_bindgen_path = meta.value()?.parse::<syn::Path>()?;
355        } else if meta.path.is_ident("unsupported") {
356            self.unsupported = Some(meta.value()?.parse::<syn::Meta>()?);
357        } else {
358            return Err(meta.error("unknown attribute"));
359        }
360        Ok(())
361    }
362}