1use proc_macro2::TokenStream;
18use quote::quote;
19use serde_tokenstream::from_tokenstream;
20use syn::spanned::Spanned;
21use usdt_impl::{CompileProvidersConfig, DataType, Probe, Provider};
22
23#[proc_macro_attribute]
25pub fn provider(
26 attr: proc_macro::TokenStream,
27 item: proc_macro::TokenStream,
28) -> proc_macro::TokenStream {
29 let attr = TokenStream::from(attr);
30 match from_tokenstream::<CompileProvidersConfig>(&attr) {
31 Ok(config) => {
32 if config.module.is_some() {
34 syn::Error::new(
35 attr.span(),
36 "The provider module may not be renamed via the attribute macro",
37 )
38 .to_compile_error()
39 .into()
40 } else {
41 generate_provider_item(TokenStream::from(item), config)
42 .unwrap_or_else(|e| e.to_compile_error())
43 .into()
44 }
45 }
46 Err(e) => e.to_compile_error().into(),
47 }
48}
49
50fn generate_provider_item(
52 item: TokenStream,
53 mut config: CompileProvidersConfig,
54) -> Result<TokenStream, syn::Error> {
55 let mod_ = syn::parse2::<syn::ItemMod>(item)?;
56 if mod_.ident == "provider" {
57 return Err(syn::Error::new(
58 mod_.ident.span(),
59 "Provider modules may not be named \"provider\"",
60 ));
61 }
62 let content = &mod_
63 .content
64 .as_ref()
65 .ok_or_else(|| {
66 syn::Error::new(mod_.span(), "Provider modules must have one or more probes")
67 })?
68 .1;
69
70 let mut check_fns = Vec::new();
71 let mut probes = Vec::new();
72 let mut use_statements = Vec::new();
73 for (fn_index, item) in content.iter().enumerate() {
74 match item {
75 syn::Item::Fn(ref func) => {
76 check_probe_name(&func.sig.ident)?;
77 let signature = check_probe_function_signature(&func.sig)?;
78 let mut item_check_fns = Vec::new();
79 let mut item_types = Vec::new();
80 for (arg_index, arg) in signature.inputs.iter().enumerate() {
81 match arg {
82 syn::FnArg::Receiver(item) => {
83 return Err(syn::Error::new(
84 item.span(),
85 "Probe functions may not take Self",
86 ));
87 }
88 syn::FnArg::Typed(ref item) => {
89 let (maybe_check_fn, item_type) =
90 parse_probe_argument(&item.ty, fn_index, arg_index)?;
91 if let Some(check_fn) = maybe_check_fn {
92 item_check_fns.push(check_fn);
93 }
94 item_types.push(item_type);
95 }
96 }
97 }
98 check_fns.extend(item_check_fns);
99 probes.push(Probe {
100 name: signature.ident.to_string(),
101 types: item_types,
102 });
103 }
104 syn::Item::Use(ref use_statement) => {
105 verify_use_tree(&use_statement.tree)?;
106 use_statements.push(use_statement.clone());
107 }
108 _ => {
109 return Err(syn::Error::new(
110 item.span(),
111 "Provider modules may only include empty functions or use statements",
112 ));
113 }
114 }
115 }
116
117 let name = match &config.provider {
122 Some(name) => {
123 let name = name.to_string();
124 config.module = Some(mod_.ident.to_string());
125 name
126 }
127 None => {
128 let name = mod_.ident.to_string();
129 config.provider = Some(name.clone());
130 config.module = Some(name.clone());
131 name
132 }
133 };
134
135 let provider = Provider {
136 name,
137 probes,
138 use_statements: use_statements.clone(),
139 };
140 let compiled = usdt_impl::compile_provider(&provider, &config);
141 let type_checks = if check_fns.is_empty() {
142 quote! { const _: fn() = || {}; }
143 } else {
144 quote! {
145 const _: fn() = || {
146 #(#use_statements)*
147 fn usdt_types_must_be_serialize<T: ?Sized + ::serde::Serialize>() {}
148 #(#check_fns)*
149 };
150 }
151 };
152 Ok(quote! {
153 #type_checks
154 #compiled
155 })
156}
157
158fn check_probe_name(ident: &syn::Ident) -> syn::Result<()> {
159 let check = |name| {
160 if ident == name {
161 Err(syn::Error::new(
162 ident.span(),
163 format!("Probe functions may not be named \"{}\"", name),
164 ))
165 } else {
166 Ok(())
167 }
168 };
169 check("probe").and(check("start"))
170}
171
172fn parse_probe_argument(
173 item: &syn::Type,
174 fn_index: usize,
175 arg_index: usize,
176) -> syn::Result<(Option<TokenStream>, DataType)> {
177 match item {
178 syn::Type::Path(ref path) => {
179 let last_ident = &path
180 .path
181 .segments
182 .last()
183 .ok_or_else(|| {
184 syn::Error::new(path.span(), "Probe arguments should resolve to path types")
185 })?
186 .ident;
187 if is_simple_type(last_ident) {
188 Ok((None, data_type_from_path(&path.path, false)))
189 } else if last_ident == "UniqueId" {
190 Ok((None, DataType::UniqueId))
191 } else {
192 let check_fn = build_serializable_check_function(item, fn_index, arg_index);
193 Ok((Some(check_fn), DataType::Serializable(item.clone())))
194 }
195 }
196 syn::Type::Ptr(ref pointer) => {
197 if pointer.mutability.is_some() {
198 return Err(syn::Error::new(item.span(), "Pointer types must be const"));
199 }
200 let ty = &*pointer.elem;
201 if let syn::Type::Path(ref path) = ty {
202 let last_ident = &path
203 .path
204 .segments
205 .last()
206 .ok_or_else(|| {
207 syn::Error::new(path.span(), "Probe arguments should resolve to path types")
208 })?
209 .ident;
210 if !is_integer_type(last_ident) {
211 return Err(syn::Error::new(
212 item.span(),
213 "Only pointers to integer types are supported",
214 ));
215 }
216 Ok((None, data_type_from_path(&path.path, true)))
217 } else {
218 Err(syn::Error::new(
219 item.span(),
220 "Only pointers to path types are supported",
221 ))
222 }
223 }
224 syn::Type::Reference(ref reference) => {
225 match parse_probe_argument(&reference.elem, fn_index, arg_index)? {
226 (None, DataType::UniqueId) => Ok((None, DataType::UniqueId)),
227 (None, DataType::Native(ty)) => Ok((None, DataType::Native(ty))),
228 _ => Ok((
229 Some(build_serializable_check_function(item, fn_index, arg_index)),
230 DataType::Serializable(item.clone()),
231 )),
232 }
233 }
234 syn::Type::Array(_) | syn::Type::Slice(_) | syn::Type::Tuple(_) => {
235 let check_fn = build_serializable_check_function(item, fn_index, arg_index);
236 Ok((Some(check_fn), DataType::Serializable(item.clone())))
237 }
238 _ => Err(syn::Error::new(
239 item.span(),
240 concat!(
241 "Probe arguments must be path types, slices, arrays, tuples, ",
242 "references, or const pointers to integers",
243 ),
244 )),
245 }
246}
247
248fn verify_use_tree(tree: &syn::UseTree) -> syn::Result<()> {
249 match tree {
250 syn::UseTree::Path(ref path) => {
251 if path.ident == "super" {
252 return Err(syn::Error::new(
253 path.span(),
254 concat!(
255 "Use-statements in USDT macros cannot contain relative imports (`super`), ",
256 "because the generated macros may be called from anywhere in a crate. ",
257 "Consider using `crate` instead.",
258 ),
259 ));
260 }
261 verify_use_tree(&path.tree)
262 }
263 _ => Ok(()),
264 }
265}
266
267fn build_serializable_check_function<T>(ident: &T, fn_index: usize, arg_index: usize) -> TokenStream
269where
270 T: quote::ToTokens,
271{
272 let fn_name = quote::format_ident!("usdt_types_must_be_serialize_{}_{}", fn_index, arg_index);
273 quote! {
274 fn #fn_name() {
275 usdt_types_must_be_serialize::<#ident>()
278 }
279 }
280}
281
282fn is_integer_type(ident: &syn::Ident) -> bool {
284 let ident = format!("{}", ident);
285 matches!(
286 ident.as_str(),
287 "u8" | "u16" | "u32" | "u64" | "i8" | "i16" | "i32" | "i64"
288 )
289}
290
291fn is_simple_type(ident: &syn::Ident) -> bool {
294 let ident = format!("{}", ident);
295 matches!(
296 ident.as_str(),
297 "u8" | "u16"
298 | "u32"
299 | "u64"
300 | "i8"
301 | "i16"
302 | "i32"
303 | "i64"
304 | "String"
305 | "str"
306 | "usize"
307 | "isize"
308 )
309}
310
311fn data_type_from_path(path: &syn::Path, pointer: bool) -> DataType {
313 use dtrace_parser::BitWidth;
314 use dtrace_parser::DataType as DType;
315 use dtrace_parser::Integer;
316 use dtrace_parser::Sign;
317
318 let variant = if pointer {
319 DType::Pointer
320 } else {
321 DType::Integer
322 };
323
324 if path.is_ident("u8") {
325 DataType::Native(variant(Integer {
326 sign: Sign::Unsigned,
327 width: BitWidth::Bit8,
328 }))
329 } else if path.is_ident("u16") {
330 DataType::Native(variant(Integer {
331 sign: Sign::Unsigned,
332 width: BitWidth::Bit16,
333 }))
334 } else if path.is_ident("u32") {
335 DataType::Native(variant(Integer {
336 sign: Sign::Unsigned,
337 width: BitWidth::Bit32,
338 }))
339 } else if path.is_ident("u64") {
340 DataType::Native(variant(Integer {
341 sign: Sign::Unsigned,
342 width: BitWidth::Bit64,
343 }))
344 } else if path.is_ident("i8") {
345 DataType::Native(variant(Integer {
346 sign: Sign::Signed,
347 width: BitWidth::Bit8,
348 }))
349 } else if path.is_ident("i16") {
350 DataType::Native(variant(Integer {
351 sign: Sign::Signed,
352 width: BitWidth::Bit16,
353 }))
354 } else if path.is_ident("i32") {
355 DataType::Native(variant(Integer {
356 sign: Sign::Signed,
357 width: BitWidth::Bit32,
358 }))
359 } else if path.is_ident("i64") {
360 DataType::Native(variant(Integer {
361 sign: Sign::Signed,
362 width: BitWidth::Bit64,
363 }))
364 } else if path.is_ident("String") || path.is_ident("str") {
365 DataType::Native(DType::String)
366 } else if path.is_ident("isize") {
367 DataType::Native(variant(Integer {
368 sign: Sign::Signed,
369 width: BitWidth::Pointer,
370 }))
371 } else if path.is_ident("usize") {
372 DataType::Native(variant(Integer {
373 sign: Sign::Unsigned,
374 width: BitWidth::Pointer,
375 }))
376 } else {
377 unreachable!("Tried to parse a non-path data type");
378 }
379}
380
381fn check_probe_function_signature(
383 signature: &syn::Signature,
384) -> Result<&syn::Signature, syn::Error> {
385 let to_err = |span, msg| Err(syn::Error::new(span, msg));
386 if let Some(item) = signature.unsafety {
387 return to_err(item.span(), "Probe functions may not be unsafe");
388 }
389 if let Some(ref item) = signature.abi {
390 return to_err(item.span(), "Probe functions may not specify an ABI");
391 }
392 if let Some(ref item) = signature.asyncness {
393 return to_err(item.span(), "Probe functions may not be async");
394 }
395 if !signature.generics.params.is_empty() {
396 return to_err(
397 signature.generics.span(),
398 "Probe functions may not be generic",
399 );
400 }
401 if !matches!(signature.output, syn::ReturnType::Default) {
402 return to_err(
403 signature.output.span(),
404 "Probe functions may not specify a return type",
405 );
406 }
407 Ok(signature)
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use dtrace_parser::BitWidth;
    use dtrace_parser::DataType as DType;
    use dtrace_parser::Integer;
    use dtrace_parser::Sign;
    use rstest::rstest;

    #[test]
    fn test_is_simple_type() {
        assert!(is_simple_type(&quote::format_ident!("u8")));
        assert!(is_simple_type(&quote::format_ident!("usize")));
        assert!(is_simple_type(&quote::format_ident!("str")));
        assert!(!is_simple_type(&quote::format_ident!("Foo")));
    }

    #[test]
    fn test_is_integer_type() {
        assert!(is_integer_type(&quote::format_ident!("i64")));
        // Pointer-sized integers are simple types but not valid pointer targets.
        assert!(!is_integer_type(&quote::format_ident!("usize")));
        assert!(!is_integer_type(&quote::format_ident!("String")));
    }

    #[test]
    fn test_data_type_from_path() {
        assert_eq!(
            data_type_from_path(&syn::parse_str("u8").unwrap(), false),
            DataType::Native(DType::Integer(Integer {
                sign: Sign::Unsigned,
                width: BitWidth::Bit8,
            })),
        );
        assert_eq!(
            data_type_from_path(&syn::parse_str("u8").unwrap(), true),
            DataType::Native(DType::Pointer(Integer {
                sign: Sign::Unsigned,
                width: BitWidth::Bit8,
            })),
        );
        assert_eq!(
            data_type_from_path(&syn::parse_str("String").unwrap(), false),
            DataType::Native(DType::String),
        );
        // Previously a verbatim duplicate of the `String` case; `str` was left untested.
        assert_eq!(
            data_type_from_path(&syn::parse_str("str").unwrap(), false),
            DataType::Native(DType::String),
        );
    }

    #[test]
    #[should_panic]
    fn test_data_type_from_path_panics() {
        data_type_from_path(&syn::parse_str("std::net::IpAddr").unwrap(), false);
    }

    #[rstest]
    #[case("u8", DType::Integer(Integer { sign: Sign::Unsigned, width: BitWidth::Bit8 }))]
    #[case("*const u8", DType::Pointer(Integer { sign: Sign::Unsigned, width: BitWidth::Bit8}))]
    #[case("&u8", DType::Integer(Integer { sign: Sign::Unsigned, width: BitWidth::Bit8 }))]
    #[case("&str", DType::String)]
    #[case("String", DType::String)]
    #[case("&&str", DType::String)]
    #[case("&String", DType::String)]
    fn test_parse_probe_argument_native(#[case] name: &str, #[case] ty: dtrace_parser::DataType) {
        let arg = syn::parse_str(name).unwrap();
        let out = parse_probe_argument(&arg, 0, 0).unwrap();
        assert!(out.0.is_none());
        assert_eq!(out.1, DataType::Native(ty));
    }

    #[rstest]
    #[case("usdt::UniqueId")]
    #[case("&usdt::UniqueId")]
    fn test_parse_probe_argument_unique_id(#[case] arg: &str) {
        let ty = syn::parse_str(arg).unwrap();
        let out = parse_probe_argument(&ty, 0, 0).unwrap();
        assert!(out.0.is_none());
        assert_eq!(out.1, DataType::UniqueId)
    }

    #[rstest]
    #[case("std::net::IpAddr")]
    #[case("&std::net::IpAddr")]
    #[case("&SomeType")]
    #[case("&&[u8]")]
    fn test_parse_probe_argument_serializable(#[case] name: &str) {
        let ty = syn::parse_str(name).unwrap();
        let out = parse_probe_argument(&ty, 0, 0).unwrap();
        assert!(out.0.is_some());
        assert_eq!(out.1, DataType::Serializable(ty));
        if let (Some(chk), DataType::Serializable(ty)) = out {
            println!("{}", quote! { #chk });
            println!("{}", quote! { #ty });
        }
    }

    #[test]
    fn test_check_probe_function_signature() {
        let signature = syn::parse_str::<syn::Signature>("fn foo(_: u8)").unwrap();
        assert!(check_probe_function_signature(&signature).is_ok());

        let check_is_err = |s| {
            let signature = syn::parse_str::<syn::Signature>(s).unwrap();
            assert!(check_probe_function_signature(&signature).is_err());
        };
        check_is_err("unsafe fn foo(_: u8)");
        check_is_err(r#"extern "C" fn foo(_: u8)"#);
        check_is_err("async fn foo(_: u8)");
        check_is_err("fn foo<T: Debug>(_: u8)");
        check_is_err("fn foo(_: u8) -> u8");
    }

    #[test]
    fn test_verify_use_tree() {
        let tokens = quote! { use std::net::IpAddr; };
        let item: syn::ItemUse = syn::parse2(tokens).unwrap();
        assert!(verify_use_tree(&item.tree).is_ok());

        let tokens = quote! { use super::SomeType; };
        let item: syn::ItemUse = syn::parse2(tokens).unwrap();
        assert!(verify_use_tree(&item.tree).is_err());

        let tokens = quote! { use crate::super::SomeType; };
        let item: syn::ItemUse = syn::parse2(tokens).unwrap();
        assert!(verify_use_tree(&item.tree).is_err());
    }
}