1#![cfg_attr(docsrs, feature(doc_cfg))]
5
6extern crate proc_macro;
7
8use std::ops::Deref as _;
9
10use proc_macro::TokenStream;
11use proc_macro2::Ident;
12use proc_macro2::Span;
13use proc_macro2::TokenStream as Tokens;
14
15use quote::quote;
16use quote::ToTokens as _;
17
18use syn::parse_macro_input;
19use syn::Attribute;
20use syn::Error;
21use syn::FnArg;
22use syn::ItemFn;
23use syn::Pat;
24use syn::Result;
25use syn::ReturnType;
26use syn::Signature;
27use syn::Type;
28
29
/// The kind of harness attribute a function can be annotated with.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Kind {
  /// A regular `#[test]`.
  Test,
  /// A `#[bench]` benchmark.
  Bench,
}

impl Kind {
  /// The attribute's name, without the surrounding `#[...]`.
  ///
  /// The result is a string literal, so it is not tied to the lifetime of
  /// `self`.
  #[inline]
  fn as_str(&self) -> &'static str {
    match self {
      Self::Test => "test",
      Self::Bench => "bench",
    }
  }
}
45
46
47#[proc_macro_attribute]
69pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
70 let input_fn = parse_macro_input!(item as ItemFn);
71
72 let has_test = input_fn
73 .attrs
74 .iter()
75 .any(|attr| is_attribute_kind(Kind::Test, attr));
76 let inner_test = if has_test {
77 quote! {}
78 } else {
79 quote! { #[::core::prelude::v1::test] }
80 };
81
82 try_test(attr, input_fn, inner_test)
83 .unwrap_or_else(syn::Error::into_compile_error)
84 .into()
85}
86
87
/// Attribute macro that turns the annotated function into a benchmark
/// whose body is executed through `::test_fork::test_fork_core::fork_in_out`.
///
/// A `#[::core::prelude::v1::bench]` attribute is emitted unless the
/// function already carries a `#[bench]` attribute. The function is
/// expected to take a single `&mut Bencher` argument. The macro accepts no
/// arguments and is only available when both the `unstable` and `unsound`
/// features are enabled.
#[cfg(all(feature = "unstable", feature = "unsound"))]
#[cfg_attr(docsrs, doc(cfg(all(feature = "unstable", feature = "unsound"))))]
#[proc_macro_attribute]
pub fn bench(attr: TokenStream, item: TokenStream) -> TokenStream {
  let input_fn = parse_macro_input!(item as ItemFn);

  let already_a_bench = input_fn
    .attrs
    .iter()
    .any(|fn_attr| is_attribute_kind(Kind::Bench, fn_attr));

  // Only add the harness attribute ourselves if the user did not.
  let inner_bench = if !already_a_bench {
    quote! { #[::core::prelude::v1::bench] }
  } else {
    quote! {}
  };

  match try_bench(attr, input_fn, inner_bench) {
    Ok(tokens) => tokens.into(),
    Err(err) => err.into_compile_error().into(),
  }
}
128
129
130#[proc_macro_attribute]
162pub fn fork(attr: TokenStream, item: TokenStream) -> TokenStream {
163 let supports_bench = cfg!(all(feature = "unstable", feature = "unsound"));
164 let input_fn = parse_macro_input!(item as ItemFn);
165
166 let has_test = input_fn
167 .attrs
168 .iter()
169 .any(|attr| is_attribute_kind(Kind::Test, attr));
170 let has_bench = supports_bench
171 && input_fn
172 .attrs
173 .iter()
174 .any(|attr| is_attribute_kind(Kind::Bench, attr));
175
176 let inner_attr = quote! {};
177 if has_test {
178 try_test(attr, input_fn, inner_attr)
179 } else if has_bench {
180 try_bench(attr, input_fn, inner_attr)
181 } else {
182 let inner_attr = if parse_bench_sig(&input_fn.sig).is_some() {
183 "#[bench]"
184 } else {
185 "#[test]"
186 };
187
188 Err(Error::new_spanned(
189 Tokens::from(attr),
190 format!("test_fork::fork requires an inner {inner_attr} attribute"),
191 ))
192 }
193 .unwrap_or_else(syn::Error::into_compile_error)
194 .into()
195}
196
197
198fn is_attribute_kind(kind: Kind, attr: &Attribute) -> bool {
204 let path = match &attr.meta {
205 syn::Meta::Path(path) => path,
206 _ => return false,
207 };
208 let candidates = [
209 ["core", "prelude", "*", kind.as_str()],
210 ["std", "prelude", "*", kind.as_str()],
211 ];
212 if path.leading_colon.is_none()
213 && path.segments.len() == 1
214 && path.segments[0].arguments.is_none()
215 && path.segments[0].ident == kind.as_str()
216 {
217 return true;
218 } else if path.segments.len() != candidates[0].len() {
219 return false;
220 }
221 candidates.into_iter().any(|segments| {
222 path.segments.iter().zip(segments).all(|(segment, path)| {
223 segment.arguments.is_none() && (path == "*" || segment.ident == path)
224 })
225 })
226}
227
228fn try_test(attr: TokenStream, input_fn: ItemFn, inner_test: Tokens) -> Result<Tokens> {
229 if !attr.is_empty() {
230 return Err(Error::new_spanned(
231 Tokens::from(attr),
232 "the attribute does not currently accept arguments",
233 ))
234 }
235
236 let ItemFn {
237 attrs,
238 vis,
239 mut sig,
240 block,
241 } = input_fn;
242
243 let test_name = sig.ident.clone();
244 let mut body_fn_sig = sig.clone();
245 body_fn_sig.ident = Ident::new("body_fn", Span::call_site());
246 sig.output = ReturnType::Default;
250
251 let augmented_test = quote! {
252 #inner_test
253 #(#attrs)*
254 #vis #sig {
255 #body_fn_sig
256 #block
257
258 ::test_fork::test_fork_core::fork(
259 ::test_fork::test_fork_core::fork_id!(),
260 ::test_fork::test_fork_core::fork_test_name!(#test_name),
261 body_fn as fn() -> _,
262 ).expect("forking test failed")
263 }
264 };
265
266 Ok(augmented_test)
267}
268
269fn parse_bench_sig(sig: &Signature) -> Option<(Pat, Type)> {
270 if sig.inputs.len() != 1 {
271 return None
272 }
273
274 if let FnArg::Typed(pat_type) = sig.inputs.first().unwrap() {
275 let ty = match pat_type.ty.deref() {
276 Type::Reference(ty_ref) => ty_ref.elem.clone(),
277 _ => return None,
278 };
279 Some((*pat_type.pat.clone(), *ty))
280 } else {
281 None
282 }
283}
284
285fn try_bench(attr: TokenStream, input_fn: ItemFn, inner_bench: Tokens) -> Result<Tokens> {
286 if !attr.is_empty() {
287 return Err(Error::new_spanned(
288 Tokens::from(attr),
289 "the attribute does not currently accept arguments",
290 ))
291 }
292
293 let ItemFn {
294 attrs,
295 vis,
296 mut sig,
297 block,
298 } = input_fn;
299
300 let (bencher_name, bencher_ty) = parse_bench_sig(&sig).ok_or_else(|| {
301 Error::new_spanned(
302 sig.to_token_stream(),
303 "benchmark function has unexpected signature (expected single `&mut Bencher` argument)",
304 )
305 })?;
306
307 let test_name = sig.ident.clone();
308 let mut body_fn_sig = sig.clone();
309 body_fn_sig.ident = Ident::new("body_fn", Span::call_site());
310 sig.output = ReturnType::Default;
311
312 let augmented_bench = quote! {
313 #inner_bench
314 #(#attrs)*
315 #vis #sig {
316 #body_fn_sig
317 #block
318
319 use ::std::mem::size_of;
320 use ::std::mem::transmute;
321
322 type BencherBuf = [u8; size_of::<#bencher_ty>()];
323
324 let buf_ref = unsafe {
329 transmute::<&mut #bencher_ty, &mut BencherBuf>(#bencher_name)
330 };
331
332 fn wrapper_fn(buf_ref: &mut [u8]) {
333 let buf_ref = <&mut BencherBuf>::try_from(buf_ref).unwrap();
334 let bench_ref = unsafe {
336 transmute::<&mut BencherBuf, &mut #bencher_ty>(buf_ref)
337 };
338 let () = body_fn(bench_ref);
339 }
340
341 ::test_fork::test_fork_core::fork_in_out(
342 ::test_fork::test_fork_core::fork_id!(),
343 ::test_fork::test_fork_core::fork_test_name!(#test_name),
344 wrapper_fn as fn(&mut [u8]) -> _,
345 buf_ref,
346 ).expect("forking test failed")
347 }
348 };
349
350 Ok(augmented_bench)
351}