Skip to main content

telemetry_safe_tracing_macros/
lib.rs

//! Proc macros for `telemetry-safe-tracing`.
//!
//! `safe_instrument` lives in this separate proc-macro crate so the public
//! tracing crate can stay a normal library and still expose helper types
//! alongside the attribute macro.
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenTree;
8use quote::quote;
9use syn::parse::{Parse, ParseStream};
10use syn::punctuated::Punctuated;
11use syn::spanned::Spanned;
12use syn::{
13    Error, Expr, Ident, ItemFn, Result, ReturnType, Token, parenthesized, parse_macro_input,
14};
15
16#[proc_macro_attribute]
17pub fn safe_instrument(attr: TokenStream, item: TokenStream) -> TokenStream {
18    let args = parse_macro_input!(attr as InstrumentArgs);
19    let item_fn = parse_macro_input!(item as ItemFn);
20
21    match expand_safe_instrument(args, item_fn) {
22        Ok(tokens) => tokens.into(),
23        Err(err) => err.to_compile_error().into(),
24    }
25}
26
27fn expand_safe_instrument(
28    args: InstrumentArgs,
29    item_fn: ItemFn,
30) -> Result<proc_macro2::TokenStream> {
31    let config = args.expand()?;
32    let ItemFn {
33        attrs,
34        vis,
35        sig,
36        block,
37    } = item_fn;
38
39    if config.record_err && sig.output == ReturnType::Default {
40        return Err(Error::new(
41            sig.ident.span(),
42            "`err` requires a `Result`-returning function",
43        ));
44    }
45
46    let record_ret_enabled = config.record_ret;
47    let record_err_enabled = config.record_err;
48    let instrument_attr = config.instrument_attr();
49    let record_ret = if record_ret_enabled {
50        Some(quote! {
51            ::telemetry_safe_tracing::__private::record_ret(
52                &__telemetry_safe_span,
53                &__telemetry_safe_result,
54            );
55        })
56    } else {
57        None
58    };
59    let record_err = if record_err_enabled {
60        Some(quote! {
61            ::telemetry_safe_tracing::__private::record_err(
62                &__telemetry_safe_span,
63                &__telemetry_safe_result,
64            );
65        })
66    } else {
67        None
68    };
69
70    let body = if sig.asyncness.is_some() {
71        quote! {
72            let __telemetry_safe_span = ::telemetry_safe_tracing::tracing::Span::current();
73            let __telemetry_safe_result = (async move #block).await;
74            #record_ret
75            #record_err
76            __telemetry_safe_result
77        }
78    } else {
79        quote! {
80            let __telemetry_safe_span = ::telemetry_safe_tracing::tracing::Span::current();
81            let __telemetry_safe_result = (|| #block)();
82            #record_ret
83            #record_err
84            __telemetry_safe_result
85        }
86    };
87
88    Ok(quote! {
89        #(#attrs)*
90        #[::telemetry_safe_tracing::tracing::instrument(#instrument_attr)]
91        #vis #sig {
92            #body
93        }
94    })
95}
96
97struct InstrumentArgs {
98    args: Punctuated<InstrumentArg, Token![,]>,
99}
100
101impl Parse for InstrumentArgs {
102    fn parse(input: ParseStream<'_>) -> Result<Self> {
103        Ok(Self {
104            args: Punctuated::parse_terminated(input)?,
105        })
106    }
107}
108
109impl InstrumentArgs {
110    fn expand(self) -> Result<InstrumentConfig> {
111        // `instrument` defaults to recording every argument via `Debug`, which is
112        // exactly the ambient escape hatch this macro exists to remove.
113        let mut config = InstrumentConfig {
114            attr_args: vec![quote! { skip_all }],
115            field_args: Vec::new(),
116            record_ret: false,
117            record_err: false,
118        };
119        for arg in self.args {
120            arg.apply(&mut config)?;
121        }
122
123        Ok(config)
124    }
125}
126
127struct InstrumentConfig {
128    attr_args: Vec<proc_macro2::TokenStream>,
129    field_args: Vec<proc_macro2::TokenStream>,
130    record_ret: bool,
131    record_err: bool,
132}
133
134impl InstrumentConfig {
135    fn instrument_attr(mut self) -> proc_macro2::TokenStream {
136        if self.record_ret {
137            self.field_args
138                .push(quote! { ret = ::telemetry_safe_tracing::tracing::field::Empty });
139        }
140        if self.record_err {
141            self.field_args
142                .push(quote! { err = ::telemetry_safe_tracing::tracing::field::Empty });
143        }
144        if !self.field_args.is_empty() {
145            let field_args = self.field_args;
146            self.attr_args.push(quote! { fields(#(#field_args),*) });
147        }
148
149        let attr_args = self.attr_args;
150        quote! { #(#attr_args),* }
151    }
152}
153
154enum InstrumentArg {
155    Flag(Ident),
156    NameValue {
157        name: Ident,
158        value: Expr,
159    },
160    List {
161        name: Ident,
162        tokens: proc_macro2::TokenStream,
163    },
164}
165
166impl Parse for InstrumentArg {
167    fn parse(input: ParseStream<'_>) -> Result<Self> {
168        let name: Ident = input.parse()?;
169
170        if input.peek(syn::token::Paren) {
171            let content;
172            parenthesized!(content in input);
173            let tokens = content.parse()?;
174            return Ok(Self::List { name, tokens });
175        }
176
177        if input.peek(Token![=]) {
178            let _: Token![=] = input.parse()?;
179            let value: Expr = input.parse()?;
180            return Ok(Self::NameValue { name, value });
181        }
182
183        Ok(Self::Flag(name))
184    }
185}
186
187impl InstrumentArg {
188    fn apply(self, config: &mut InstrumentConfig) -> Result<()> {
189        match self {
190            Self::Flag(name) => match name.to_string().as_str() {
191                "skip_all" => Ok(()),
192                "ret" => {
193                    config.record_ret = true;
194                    Ok(())
195                }
196                "err" => {
197                    config.record_err = true;
198                    Ok(())
199                }
200                _ => Err(Error::new(
201                    name.span(),
202                    "unsupported safe_instrument flag; only `skip_all`, `skip(...)`, `name`, `level`, `target`, `ret`, `err`, and `fields(...)` are currently supported",
203                )),
204            },
205            Self::NameValue { name, value } => match name.to_string().as_str() {
206                "name" | "level" | "target" => {
207                    config.attr_args.push(quote! { #name = #value });
208                    Ok(())
209                }
210                _ => Err(Error::new(
211                    name.span(),
212                    "unsupported safe_instrument option; only `name`, `level`, `target`, `skip(...)`, `skip_all`, `ret`, `err`, and `fields(...)` are currently supported",
213                )),
214            },
215            Self::List { name, tokens } => match name.to_string().as_str() {
216                // `safe_instrument` already forces `skip_all`, so forwarding a
217                // partial skip list would only add confusing, redundant syntax.
218                "skip" => Ok(()),
219                "fields" => {
220                    let fields = syn::parse2::<FieldArgs>(tokens)?;
221                    config.field_args.extend(fields.expand()?);
222                    Ok(())
223                }
224                _ => Err(Error::new(
225                    name.span(),
226                    "unsupported safe_instrument list; only `skip(...)` and `fields(...)` are currently supported",
227                )),
228            },
229        }
230    }
231}
232
233struct FieldArgs {
234    fields: Punctuated<FieldArg, Token![,]>,
235}
236
237impl Parse for FieldArgs {
238    fn parse(input: ParseStream<'_>) -> Result<Self> {
239        Ok(Self {
240            fields: Punctuated::parse_terminated(input)?,
241        })
242    }
243}
244
245impl FieldArgs {
246    fn expand(self) -> Result<Vec<proc_macro2::TokenStream>> {
247        let mut expanded = Vec::with_capacity(self.fields.len());
248        for field in self.fields {
249            expanded.push(field.expand()?);
250        }
251
252        Ok(expanded)
253    }
254}
255
256struct FieldArg {
257    name: proc_macro2::TokenStream,
258    kind: FieldValueKind,
259}
260
261impl Parse for FieldArg {
262    fn parse(input: ParseStream<'_>) -> Result<Self> {
263        let mut name = proc_macro2::TokenStream::new();
264        while !input.peek(Token![=]) {
265            if input.is_empty() {
266                return Err(input.error("field entries must use `name = %expr`"));
267            }
268
269            let tt: TokenTree = input.parse()?;
270            name.extend(std::iter::once(tt));
271        }
272
273        let _: Token![=] = input.parse()?;
274
275        let kind = if input.peek(Token![%]) {
276            let _: Token![%] = input.parse()?;
277            FieldValueKind::Display(input.parse()?)
278        } else if input.peek(Token![?]) {
279            let mark: Token![?] = input.parse()?;
280            let _expr: Expr = input.parse()?;
281            return Err(Error::new(
282                mark.span,
283                "`?expr` is intentionally unsupported in safe_instrument; use `%expr` with a ToTelemetry value instead",
284            ));
285        } else {
286            let value: Expr = input.parse()?;
287            return Err(Error::new(
288                value.span(),
289                "field entries must use `%expr`; implicit value formatting is intentionally unsupported",
290            ));
291        };
292
293        Ok(Self { name, kind })
294    }
295}
296
297impl FieldArg {
298    fn expand(self) -> Result<proc_macro2::TokenStream> {
299        if self.name.is_empty() {
300            return Err(Error::new(
301                proc_macro2::Span::call_site(),
302                "field name cannot be empty",
303            ));
304        }
305
306        let name = self.name;
307        match self.kind {
308            FieldValueKind::Display(expr) => Ok(quote! {
309                #name = %::telemetry_safe_tracing::telemetry(&(#expr))
310            }),
311        }
312    }
313}
314
315enum FieldValueKind {
316    Display(Expr),
317}