usdt_attr_macro/
lib.rs

1//! Generate USDT probes from an attribute macro
2
3// Copyright 2024 Oxide Computer Company
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16
17use proc_macro2::TokenStream;
18use quote::quote;
19use serde_tokenstream::from_tokenstream;
20use syn::spanned::Spanned;
21use usdt_impl::{CompileProvidersConfig, DataType, Probe, Provider};
22
23/// Generate a provider from functions defined in a Rust module.
24#[proc_macro_attribute]
25pub fn provider(
26    attr: proc_macro::TokenStream,
27    item: proc_macro::TokenStream,
28) -> proc_macro::TokenStream {
29    let attr = TokenStream::from(attr);
30    match from_tokenstream::<CompileProvidersConfig>(&attr) {
31        Ok(config) => {
32            // Renaming the module via the attribute macro isn't supported.
33            if config.module.is_some() {
34                syn::Error::new(
35                    attr.span(),
36                    "The provider module may not be renamed via the attribute macro",
37                )
38                .to_compile_error()
39                .into()
40            } else {
41                generate_provider_item(TokenStream::from(item), config)
42                    .unwrap_or_else(|e| e.to_compile_error())
43                    .into()
44            }
45        }
46        Err(e) => e.to_compile_error().into(),
47    }
48}
49
50// Generate the actual provider implementation, include the type-checks and probe macros.
51fn generate_provider_item(
52    item: TokenStream,
53    mut config: CompileProvidersConfig,
54) -> Result<TokenStream, syn::Error> {
55    let mod_ = syn::parse2::<syn::ItemMod>(item)?;
56    if mod_.ident == "provider" {
57        return Err(syn::Error::new(
58            mod_.ident.span(),
59            "Provider modules may not be named \"provider\"",
60        ));
61    }
62    let content = &mod_
63        .content
64        .as_ref()
65        .ok_or_else(|| {
66            syn::Error::new(mod_.span(), "Provider modules must have one or more probes")
67        })?
68        .1;
69
70    let mut check_fns = Vec::new();
71    let mut probes = Vec::new();
72    let mut use_statements = Vec::new();
73    for (fn_index, item) in content.iter().enumerate() {
74        match item {
75            syn::Item::Fn(ref func) => {
76                check_probe_name(&func.sig.ident)?;
77                let signature = check_probe_function_signature(&func.sig)?;
78                let mut item_check_fns = Vec::new();
79                let mut item_types = Vec::new();
80                for (arg_index, arg) in signature.inputs.iter().enumerate() {
81                    match arg {
82                        syn::FnArg::Receiver(item) => {
83                            return Err(syn::Error::new(
84                                item.span(),
85                                "Probe functions may not take Self",
86                            ));
87                        }
88                        syn::FnArg::Typed(ref item) => {
89                            let (maybe_check_fn, item_type) =
90                                parse_probe_argument(&item.ty, fn_index, arg_index)?;
91                            if let Some(check_fn) = maybe_check_fn {
92                                item_check_fns.push(check_fn);
93                            }
94                            item_types.push(item_type);
95                        }
96                    }
97                }
98                check_fns.extend(item_check_fns);
99                probes.push(Probe {
100                    name: signature.ident.to_string(),
101                    types: item_types,
102                });
103            }
104            syn::Item::Use(ref use_statement) => {
105                verify_use_tree(&use_statement.tree)?;
106                use_statements.push(use_statement.clone());
107            }
108            _ => {
109                return Err(syn::Error::new(
110                    item.span(),
111                    "Provider modules may only include empty functions or use statements",
112                ));
113            }
114        }
115    }
116
117    // We're guaranteed that the module name in the config is None. If the user has set the
118    // provider name there, extract it. If they have _not_ set the provider name there, extract the
119    // module name. In both cases, we don't support renaming the module via this path, so the
120    // module name is passed through.
121    let name = match &config.provider {
122        Some(name) => {
123            let name = name.to_string();
124            config.module = Some(mod_.ident.to_string());
125            name
126        }
127        None => {
128            let name = mod_.ident.to_string();
129            config.provider = Some(name.clone());
130            config.module = Some(name.clone());
131            name
132        }
133    };
134
135    let provider = Provider {
136        name,
137        probes,
138        use_statements: use_statements.clone(),
139    };
140    let compiled = usdt_impl::compile_provider(&provider, &config);
141    let type_checks = if check_fns.is_empty() {
142        quote! { const _: fn() = || {}; }
143    } else {
144        quote! {
145            const _: fn() = || {
146                #(#use_statements)*
147                fn usdt_types_must_be_serialize<T: ?Sized + ::serde::Serialize>() {}
148                #(#check_fns)*
149            };
150        }
151    };
152    Ok(quote! {
153        #type_checks
154        #compiled
155    })
156}
157
158fn check_probe_name(ident: &syn::Ident) -> syn::Result<()> {
159    let check = |name| {
160        if ident == name {
161            Err(syn::Error::new(
162                ident.span(),
163                format!("Probe functions may not be named \"{}\"", name),
164            ))
165        } else {
166            Ok(())
167        }
168    };
169    check("probe").and(check("start"))
170}
171
172fn parse_probe_argument(
173    item: &syn::Type,
174    fn_index: usize,
175    arg_index: usize,
176) -> syn::Result<(Option<TokenStream>, DataType)> {
177    match item {
178        syn::Type::Path(ref path) => {
179            let last_ident = &path
180                .path
181                .segments
182                .last()
183                .ok_or_else(|| {
184                    syn::Error::new(path.span(), "Probe arguments should resolve to path types")
185                })?
186                .ident;
187            if is_simple_type(last_ident) {
188                Ok((None, data_type_from_path(&path.path, false)))
189            } else if last_ident == "UniqueId" {
190                Ok((None, DataType::UniqueId))
191            } else {
192                let check_fn = build_serializable_check_function(item, fn_index, arg_index);
193                Ok((Some(check_fn), DataType::Serializable(item.clone())))
194            }
195        }
196        syn::Type::Ptr(ref pointer) => {
197            if pointer.mutability.is_some() {
198                return Err(syn::Error::new(item.span(), "Pointer types must be const"));
199            }
200            let ty = &*pointer.elem;
201            if let syn::Type::Path(ref path) = ty {
202                let last_ident = &path
203                    .path
204                    .segments
205                    .last()
206                    .ok_or_else(|| {
207                        syn::Error::new(path.span(), "Probe arguments should resolve to path types")
208                    })?
209                    .ident;
210                if !is_integer_type(last_ident) {
211                    return Err(syn::Error::new(
212                        item.span(),
213                        "Only pointers to integer types are supported",
214                    ));
215                }
216                Ok((None, data_type_from_path(&path.path, true)))
217            } else {
218                Err(syn::Error::new(
219                    item.span(),
220                    "Only pointers to path types are supported",
221                ))
222            }
223        }
224        syn::Type::Reference(ref reference) => {
225            match parse_probe_argument(&reference.elem, fn_index, arg_index)? {
226                (None, DataType::UniqueId) => Ok((None, DataType::UniqueId)),
227                (None, DataType::Native(ty)) => Ok((None, DataType::Native(ty))),
228                _ => Ok((
229                    Some(build_serializable_check_function(item, fn_index, arg_index)),
230                    DataType::Serializable(item.clone()),
231                )),
232            }
233        }
234        syn::Type::Array(_) | syn::Type::Slice(_) | syn::Type::Tuple(_) => {
235            let check_fn = build_serializable_check_function(item, fn_index, arg_index);
236            Ok((Some(check_fn), DataType::Serializable(item.clone())))
237        }
238        _ => Err(syn::Error::new(
239            item.span(),
240            concat!(
241                "Probe arguments must be path types, slices, arrays, tuples, ",
242                "references, or const pointers to integers",
243            ),
244        )),
245    }
246}
247
248fn verify_use_tree(tree: &syn::UseTree) -> syn::Result<()> {
249    match tree {
250        syn::UseTree::Path(ref path) => {
251            if path.ident == "super" {
252                return Err(syn::Error::new(
253                    path.span(),
254                    concat!(
255                        "Use-statements in USDT macros cannot contain relative imports (`super`), ",
256                        "because the generated macros may be called from anywhere in a crate. ",
257                        "Consider using `crate` instead.",
258                    ),
259                ));
260            }
261            verify_use_tree(&path.tree)
262        }
263        _ => Ok(()),
264    }
265}
266
267// Create a function that statically asserts the given identifier implements `Serialize`.
268fn build_serializable_check_function<T>(ident: &T, fn_index: usize, arg_index: usize) -> TokenStream
269where
270    T: quote::ToTokens,
271{
272    let fn_name = quote::format_ident!("usdt_types_must_be_serialize_{}_{}", fn_index, arg_index);
273    quote! {
274        fn #fn_name() {
275            // #ident must be in scope here, because this function is defined in the same module as
276            // the actual probe functions, and thus shares any imports the consumer wants.
277            usdt_types_must_be_serialize::<#ident>()
278        }
279    }
280}
281
282// Return `true` if the type is an integer
283fn is_integer_type(ident: &syn::Ident) -> bool {
284    let ident = format!("{}", ident);
285    matches!(
286        ident.as_str(),
287        "u8" | "u16" | "u32" | "u64" | "i8" | "i16" | "i32" | "i64"
288    )
289}
290
291// Return `true` if this type is "simple", a primitive type with an analog in D, i.e., _not_ a
292// type that implements `Serialize`.
293fn is_simple_type(ident: &syn::Ident) -> bool {
294    let ident = format!("{}", ident);
295    matches!(
296        ident.as_str(),
297        "u8" | "u16"
298            | "u32"
299            | "u64"
300            | "i8"
301            | "i16"
302            | "i32"
303            | "i64"
304            | "String"
305            | "str"
306            | "usize"
307            | "isize"
308    )
309}
310
311// Return the `dtrace_parser::DataType` corresponding to the given `path`
312fn data_type_from_path(path: &syn::Path, pointer: bool) -> DataType {
313    use dtrace_parser::BitWidth;
314    use dtrace_parser::DataType as DType;
315    use dtrace_parser::Integer;
316    use dtrace_parser::Sign;
317
318    let variant = if pointer {
319        DType::Pointer
320    } else {
321        DType::Integer
322    };
323
324    if path.is_ident("u8") {
325        DataType::Native(variant(Integer {
326            sign: Sign::Unsigned,
327            width: BitWidth::Bit8,
328        }))
329    } else if path.is_ident("u16") {
330        DataType::Native(variant(Integer {
331            sign: Sign::Unsigned,
332            width: BitWidth::Bit16,
333        }))
334    } else if path.is_ident("u32") {
335        DataType::Native(variant(Integer {
336            sign: Sign::Unsigned,
337            width: BitWidth::Bit32,
338        }))
339    } else if path.is_ident("u64") {
340        DataType::Native(variant(Integer {
341            sign: Sign::Unsigned,
342            width: BitWidth::Bit64,
343        }))
344    } else if path.is_ident("i8") {
345        DataType::Native(variant(Integer {
346            sign: Sign::Signed,
347            width: BitWidth::Bit8,
348        }))
349    } else if path.is_ident("i16") {
350        DataType::Native(variant(Integer {
351            sign: Sign::Signed,
352            width: BitWidth::Bit16,
353        }))
354    } else if path.is_ident("i32") {
355        DataType::Native(variant(Integer {
356            sign: Sign::Signed,
357            width: BitWidth::Bit32,
358        }))
359    } else if path.is_ident("i64") {
360        DataType::Native(variant(Integer {
361            sign: Sign::Signed,
362            width: BitWidth::Bit64,
363        }))
364    } else if path.is_ident("String") || path.is_ident("str") {
365        DataType::Native(DType::String)
366    } else if path.is_ident("isize") {
367        DataType::Native(variant(Integer {
368            sign: Sign::Signed,
369            width: BitWidth::Pointer,
370        }))
371    } else if path.is_ident("usize") {
372        DataType::Native(variant(Integer {
373            sign: Sign::Unsigned,
374            width: BitWidth::Pointer,
375        }))
376    } else {
377        unreachable!("Tried to parse a non-path data type");
378    }
379}
380
381// Sanity checks on a probe function signature.
382fn check_probe_function_signature(
383    signature: &syn::Signature,
384) -> Result<&syn::Signature, syn::Error> {
385    let to_err = |span, msg| Err(syn::Error::new(span, msg));
386    if let Some(item) = signature.unsafety {
387        return to_err(item.span(), "Probe functions may not be unsafe");
388    }
389    if let Some(ref item) = signature.abi {
390        return to_err(item.span(), "Probe functions may not specify an ABI");
391    }
392    if let Some(ref item) = signature.asyncness {
393        return to_err(item.span(), "Probe functions may not be async");
394    }
395    if !signature.generics.params.is_empty() {
396        return to_err(
397            signature.generics.span(),
398            "Probe functions may not be generic",
399        );
400    }
401    if !matches!(signature.output, syn::ReturnType::Default) {
402        return to_err(
403            signature.output.span(),
404            "Probe functions may not specify a return type",
405        );
406    }
407    Ok(signature)
408}
409
// Unit tests for the argument classification, signature validation and use-tree checks.
#[cfg(test)]
mod tests {
    use super::*;
    use dtrace_parser::BitWidth;
    use dtrace_parser::DataType as DType;
    use dtrace_parser::Integer;
    use dtrace_parser::Sign;
    use rstest::rstest;

    #[test]
    fn test_is_simple_type() {
        assert!(is_simple_type(&quote::format_ident!("u8")));
        assert!(!is_simple_type(&quote::format_ident!("Foo")));
    }

    #[test]
    fn test_data_type_from_path() {
        assert_eq!(
            data_type_from_path(&syn::parse_str("u8").unwrap(), false),
            DataType::Native(DType::Integer(Integer {
                sign: Sign::Unsigned,
                width: BitWidth::Bit8,
            })),
        );
        assert_eq!(
            data_type_from_path(&syn::parse_str("u8").unwrap(), true),
            DataType::Native(DType::Pointer(Integer {
                sign: Sign::Unsigned,
                width: BitWidth::Bit8,
            })),
        );
        assert_eq!(
            data_type_from_path(&syn::parse_str("String").unwrap(), false),
            DataType::Native(DType::String),
        );
        // Both string spellings map to the same native string type.
        assert_eq!(
            data_type_from_path(&syn::parse_str("str").unwrap(), false),
            DataType::Native(DType::String),
        );
    }

    // A path that does not name a native type must never reach `data_type_from_path`.
    #[test]
    #[should_panic]
    fn test_data_type_from_path_panics() {
        data_type_from_path(&syn::parse_str("std::net::IpAddr").unwrap(), false);
    }

    // Native types, and references to them, need no serialization check function.
    #[rstest]
    #[case("u8", DType::Integer(Integer { sign: Sign::Unsigned, width: BitWidth::Bit8 }))]
    #[case("*const u8", DType::Pointer(Integer { sign: Sign::Unsigned, width: BitWidth::Bit8}))]
    #[case("&u8", DType::Integer(Integer { sign: Sign::Unsigned, width: BitWidth::Bit8 }))]
    #[case("&str", DType::String)]
    #[case("String", DType::String)]
    #[case("&&str", DType::String)]
    #[case("&String", DType::String)]
    fn test_parse_probe_argument_native(#[case] name: &str, #[case] ty: dtrace_parser::DataType) {
        let arg = syn::parse_str(name).unwrap();
        let out = parse_probe_argument(&arg, 0, 0).unwrap();
        assert!(out.0.is_none());
        assert_eq!(out.1, DataType::Native(ty));
    }

    // `UniqueId` is recognized both by value and behind a reference.
    #[rstest]
    #[case("usdt::UniqueId")]
    #[case("&usdt::UniqueId")]
    fn test_parse_probe_argument_span(#[case] arg: &str) {
        let ty = syn::parse_str(arg).unwrap();
        let out = parse_probe_argument(&ty, 0, 0).unwrap();
        assert!(out.0.is_none());
        assert_eq!(out.1, DataType::UniqueId)
    }

    // Everything else is serializable and carries a check function for the type as written.
    #[rstest]
    #[case("std::net::IpAddr")]
    #[case("&std::net::IpAddr")]
    #[case("&SomeType")]
    #[case("&&[u8]")]
    fn test_parse_probe_argument_serializable(#[case] name: &str) {
        let ty = syn::parse_str(name).unwrap();
        let out = parse_probe_argument(&ty, 0, 0).unwrap();
        assert!(out.0.is_some());
        assert_eq!(out.1, DataType::Serializable(ty));
        if let (Some(chk), DataType::Serializable(ty)) = out {
            println!("{}", quote! { #chk });
            println!("{}", quote! { #ty });
        }
    }

    #[test]
    fn test_check_probe_function_signature() {
        let signature = syn::parse_str::<syn::Signature>("fn foo(_: u8)").unwrap();
        assert!(check_probe_function_signature(&signature).is_ok());

        let check_is_err = |s| {
            let signature = syn::parse_str::<syn::Signature>(s).unwrap();
            assert!(check_probe_function_signature(&signature).is_err());
        };
        // One case per rejected signature feature.
        check_is_err("unsafe fn foo(_: u8)");
        check_is_err(r#"extern "C" fn foo(_: u8)"#);
        check_is_err("async fn foo(_: u8)");
        check_is_err("fn foo<T: Debug>(_: u8)");
        check_is_err("fn foo(_: u8) -> u8");
    }

    #[test]
    fn test_verify_use_tree() {
        let tokens = quote! { use std::net::IpAddr; };
        let item: syn::ItemUse = syn::parse2(tokens).unwrap();
        assert!(verify_use_tree(&item.tree).is_ok());

        let tokens = quote! { use super::SomeType; };
        let item: syn::ItemUse = syn::parse2(tokens).unwrap();
        assert!(verify_use_tree(&item.tree).is_err());

        // `super` is rejected wherever it appears in the path, not only as the first segment.
        let tokens = quote! { use crate::super::SomeType; };
        let item: syn::ItemUse = syn::parse2(tokens).unwrap();
        assert!(verify_use_tree(&item.tree).is_err());
    }
}
527}