test_fork_macros/
lib.rs

1// Copyright (C) 2025 Daniel Mueller <deso@posteo.net>
2// SPDX-License-Identifier: (Apache-2.0 OR MIT)
3
4#![cfg_attr(docsrs, feature(doc_cfg))]
5
6extern crate proc_macro;
7
8use std::ops::Deref as _;
9
10use proc_macro::TokenStream;
11use proc_macro2::Ident;
12use proc_macro2::Span;
13use proc_macro2::TokenStream as Tokens;
14
15use quote::quote;
16use quote::ToTokens as _;
17
18use syn::parse_macro_input;
19use syn::Attribute;
20use syn::Error;
21use syn::FnArg;
22use syn::ItemFn;
23use syn::Pat;
24use syn::Result;
25use syn::ReturnType;
26use syn::Signature;
27use syn::Type;
28
29
/// The kind of built-in harness attribute a function may carry.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Kind {
    /// The `#[test]` attribute.
    Test,
    /// The `#[bench]` attribute (only supported with the `unstable` and
    /// `unsound` features).
    Bench,
}

impl Kind {
    /// Return the bare attribute name for this kind, i.e., the `<name>`
    /// in `#[<name>]`.
    ///
    /// The result is a string literal, so it is `'static` and not tied
    /// to the lifetime of `self`.
    #[inline]
    fn as_str(&self) -> &'static str {
        match self {
            Self::Test => "test",
            Self::Bench => "bench",
        }
    }
}
45
46
47/// A procedural macro for running a test in a separate process.
48///
49/// # Example
50///
51/// Use the attribute for all tests in scope:
52/// ```rust,ignore
53/// use test_fork::test;
54///
55/// #[test]
56/// fn test1() {
57///   assert_eq!(2 + 2, 4);
58/// }
59/// ```
60///
61/// Use it only on a single test:
62/// ```rust,ignore
63/// #[test_fork::test]
64/// fn test2() {
65///   assert_eq!(2 + 3, 5);
66/// }
67/// ```
68#[proc_macro_attribute]
69pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
70    let input_fn = parse_macro_input!(item as ItemFn);
71
72    let has_test = input_fn
73        .attrs
74        .iter()
75        .any(|attr| is_attribute_kind(Kind::Test, attr));
76    let inner_test = if has_test {
77        quote! {}
78    } else {
79        quote! { #[::core::prelude::v1::test] }
80    };
81
82    try_test(attr, input_fn, inner_test)
83        .unwrap_or_else(syn::Error::into_compile_error)
84        .into()
85}
86
87
/// A procedural macro for running a benchmark in a separate process.
///
/// When the function does not already carry a `#[bench]` attribute, the
/// built-in one is added automatically. The benchmark function must take
/// exactly one `&mut Bencher` argument, and the attribute itself does not
/// accept any arguments; violating either results in a compile error.
///
/// # Example
///
/// Use the attribute for all benchmarks in scope:
/// ```rust,ignore
/// use test_fork::bench;
///
/// #[bench]
/// fn bench1(b: &mut Bencher) {
///   b.iter(|| sleep(Duration::from_millis(1)));
/// }
/// ```
///
/// Use it only on a single benchmark:
/// ```rust,ignore
/// #[test_fork::bench]
/// fn bench2(b: &mut Bencher) {
///   b.iter(|| sleep(Duration::from_millis(1)));
/// }
/// ```
#[cfg(all(feature = "unstable", feature = "unsound"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", feature = "unsound"))))]
#[proc_macro_attribute]
pub fn bench(attr: TokenStream, item: TokenStream) -> TokenStream {
    let input_fn = parse_macro_input!(item as ItemFn);

    // Only add the built-in `#[bench]` attribute if the user did not
    // already provide one, so the function is not annotated twice.
    let has_bench = input_fn
        .attrs
        .iter()
        .any(|attr| is_attribute_kind(Kind::Bench, attr));
    let inner_bench = if has_bench {
        quote! {}
    } else {
        quote! { #[::core::prelude::v1::bench] }
    };

    try_bench(attr, input_fn, inner_bench)
        .unwrap_or_else(syn::Error::into_compile_error)
        .into()
}
128
129
130/// A procedural macro for running a test or benchmark in a separate
131/// process.
132///
133/// This attribute is able to cater to both tests and benchmarks, while
134/// #[[macro@test]] is specific to tests and #[[macro@bench]] to
135/// benchmarks.
136///
137/// Contrary to both, this attribute does not in itself make a function
138/// a test/benchmark, so it will *always* have to be combined with an
139/// additional "inner" attribute. However, it can be more convenient for
140/// annotating only a sub-set of tests/benchmarks for running in
141/// separate processes, especially when non-standard attributes are
142/// involved:
143///
144/// # Example
145///
146/// ```rust,ignore
147/// use test_fork::fork;
148///
149/// #[fork]
150/// #[test]
151/// fn test3() {
152///   assert_eq!(2 + 4, 6);
153/// }
154///
155/// #[fork]
156/// #[bench]
157/// fn bench3(b: &mut Bencher) {
158///   b.iter(|| sleep(Duration::from_millis(1)));
159/// }
160/// ```
161#[proc_macro_attribute]
162pub fn fork(attr: TokenStream, item: TokenStream) -> TokenStream {
163    let supports_bench = cfg!(all(feature = "unstable", feature = "unsound"));
164    let input_fn = parse_macro_input!(item as ItemFn);
165
166    let has_test = input_fn
167        .attrs
168        .iter()
169        .any(|attr| is_attribute_kind(Kind::Test, attr));
170    let has_bench = supports_bench
171        && input_fn
172            .attrs
173            .iter()
174            .any(|attr| is_attribute_kind(Kind::Bench, attr));
175
176    let inner_attr = quote! {};
177    if has_test {
178        try_test(attr, input_fn, inner_attr)
179    } else if has_bench {
180        try_bench(attr, input_fn, inner_attr)
181    } else {
182        let inner_attr = if parse_bench_sig(&input_fn.sig).is_some() {
183            "#[bench]"
184        } else {
185            "#[test]"
186        };
187
188        Err(Error::new_spanned(
189            Tokens::from(attr),
190            format!("test_fork::fork requires an inner {inner_attr} attribute"),
191        ))
192    }
193    .unwrap_or_else(syn::Error::into_compile_error)
194    .into()
195}
196
197
198/// Check whether given attribute is a test or bench attribute of the
199/// form:
200/// - `#[<kind>]`
201/// - `#[core::prelude::*::<kind>]` or `#[::core::prelude::*::<kind>]`
202/// - `#[std::prelude::*::<kind>]` or `#[::std::prelude::*::<kind>]`
203fn is_attribute_kind(kind: Kind, attr: &Attribute) -> bool {
204    let path = match &attr.meta {
205        syn::Meta::Path(path) => path,
206        _ => return false,
207    };
208    let candidates = [
209        ["core", "prelude", "*", kind.as_str()],
210        ["std", "prelude", "*", kind.as_str()],
211    ];
212    if path.leading_colon.is_none()
213        && path.segments.len() == 1
214        && path.segments[0].arguments.is_none()
215        && path.segments[0].ident == kind.as_str()
216    {
217        return true;
218    } else if path.segments.len() != candidates[0].len() {
219        return false;
220    }
221    candidates.into_iter().any(|segments| {
222        path.segments.iter().zip(segments).all(|(segment, path)| {
223            segment.arguments.is_none() && (path == "*" || segment.ident == path)
224        })
225    })
226}
227
228fn try_test(attr: TokenStream, input_fn: ItemFn, inner_test: Tokens) -> Result<Tokens> {
229    if !attr.is_empty() {
230        return Err(Error::new_spanned(
231            Tokens::from(attr),
232            "the attribute does not currently accept arguments",
233        ))
234    }
235
236    let ItemFn {
237        attrs,
238        vis,
239        mut sig,
240        block,
241    } = input_fn;
242
243    let test_name = sig.ident.clone();
244    let mut body_fn_sig = sig.clone();
245    body_fn_sig.ident = Ident::new("body_fn", Span::call_site());
246    // Our tests currently basically have to return (), because we don't
247    // have a good way of conveying the result back from the child
248    // process.
249    sig.output = ReturnType::Default;
250
251    let augmented_test = quote! {
252        #inner_test
253        #(#attrs)*
254        #vis #sig {
255            #body_fn_sig
256            #block
257
258            ::test_fork::test_fork_core::fork(
259                ::test_fork::test_fork_core::fork_id!(),
260                ::test_fork::test_fork_core::fork_test_name!(#test_name),
261                body_fn as fn() -> _,
262            ).expect("forking test failed")
263        }
264    };
265
266    Ok(augmented_test)
267}
268
269fn parse_bench_sig(sig: &Signature) -> Option<(Pat, Type)> {
270    if sig.inputs.len() != 1 {
271        return None
272    }
273
274    if let FnArg::Typed(pat_type) = sig.inputs.first().unwrap() {
275        let ty = match pat_type.ty.deref() {
276            Type::Reference(ty_ref) => ty_ref.elem.clone(),
277            _ => return None,
278        };
279        Some((*pat_type.pat.clone(), *ty))
280    } else {
281        None
282    }
283}
284
285fn try_bench(attr: TokenStream, input_fn: ItemFn, inner_bench: Tokens) -> Result<Tokens> {
286    if !attr.is_empty() {
287        return Err(Error::new_spanned(
288            Tokens::from(attr),
289            "the attribute does not currently accept arguments",
290        ))
291    }
292
293    let ItemFn {
294        attrs,
295        vis,
296        mut sig,
297        block,
298    } = input_fn;
299
300    let (bencher_name, bencher_ty) = parse_bench_sig(&sig).ok_or_else(|| {
301        Error::new_spanned(
302            sig.to_token_stream(),
303            "benchmark function has unexpected signature (expected single `&mut Bencher` argument)",
304        )
305    })?;
306
307    let test_name = sig.ident.clone();
308    let mut body_fn_sig = sig.clone();
309    body_fn_sig.ident = Ident::new("body_fn", Span::call_site());
310    sig.output = ReturnType::Default;
311
312    let augmented_bench = quote! {
313        #inner_bench
314        #(#attrs)*
315        #vis #sig {
316            #body_fn_sig
317            #block
318
319            use ::std::mem::size_of;
320            use ::std::mem::transmute;
321
322            type BencherBuf = [u8; size_of::<#bencher_ty>()];
323
324            // SAFETY: Probably unsound. We can't guarantee that the
325            //         `Bencher` type is just a bunch of bytes that we
326            //         can copy around. And yet, that's the best we can
327            //         do.
328            let buf_ref = unsafe {
329                transmute::<&mut #bencher_ty, &mut BencherBuf>(#bencher_name)
330            };
331
332            fn wrapper_fn(buf_ref: &mut [u8]) {
333                let buf_ref = <&mut BencherBuf>::try_from(buf_ref).unwrap();
334                // SAFETY: See above.
335                let bench_ref = unsafe {
336                    transmute::<&mut BencherBuf, &mut #bencher_ty>(buf_ref)
337                };
338                let () = body_fn(bench_ref);
339            }
340
341            ::test_fork::test_fork_core::fork_in_out(
342                ::test_fork::test_fork_core::fork_id!(),
343                ::test_fork::test_fork_core::fork_test_name!(#test_name),
344                wrapper_fn as fn(&mut [u8]) -> _,
345                buf_ref,
346            ).expect("forking test failed")
347        }
348    };
349
350    Ok(augmented_bench)
351}